@@ -125,6 +125,7 @@ def get_deps(
             # detected precision but it is either disabled or not supported
             why(node, f"Unsupported precision type {precision}")
             return (False, [])
+        _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -139,11 +140,6 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
-            # if force_non_static_weights_for_f32_linear is enabled, then we
-            # do not partition the weight node
-            return (True, gemm_deps)
-
         if precision == ConfigPrecisionType.FP32:
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
@@ -225,8 +221,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
-            # if force for_fp32_linear_as_matmul is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)
 
@@ -304,6 +303,14 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
+            # do not partition the weight node
+            return (True, [])
+
         # Since we are in Linear, we may assume that the weights are indeed static.
         overwritten_linear_precision, new_precision = self._overwrite_precision(node)
         if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
@@ -403,6 +410,19 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -511,6 +531,19 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
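
Usage note: below is a minimal sketch of how the force_non_static_weights_for_f32_linear option threaded through these configs might be exercised. It assumes XnnpackPartitioner forwards its keyword arguments down to the GEMM configs shown above; the export and lowering calls are standard ExecuTorch APIs, not part of this diff, and the model is purely illustrative.

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.linear(x)


ep = torch.export.export(TinyLinear().eval(), (torch.randn(1, 32),))

# With the flag enabled (assumed to be forwarded via **kwargs), fp32 linear
# weights return (True, []) from _get_weight_deps, so the weight nodes are not
# pulled into the XNNPACK partition and stay non-static at runtime.
lowered = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)],
)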