@@ -337,6 +337,16 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # If force_fp32_dynamic_linear is on and we detected this as fp32,
+            # then we do not partition the weight node.
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -436,6 +446,16 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # If force_fp32_dynamic_linear is on and we detected this as fp32,
+            # then we do not partition the weight node.
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
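
For context, a minimal sketch of how this flag might be exercised, assuming the
XNNPACK partitioner forwards keyword arguments such as force_fp32_dynamic_linear
down to its GEMM configs (the import path and constructor usage below are
assumptions, not confirmed by this diff):

    # Hypothetical usage sketch: with the flag set, fp32 dynamic linear
    # weights are left out of the delegated partition (see the early
    # `return (True, [])` in _get_weight_deps above).
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackPartitioner,
    )

    # Assumption: the partitioner passes **kwargs through to each config,
    # where the flag is read as self.force_fp32_dynamic_linear.
    partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=True)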