@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
-                # even with in a quantized graph
+                # when only fp32 is enabled, we can still partition fp32 gemms
+                # even within a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision
 
     def get_deps(
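For readers outside this file: `_overwrite_precision` returns an `(overwritten, precision)` pair, and the added blank line separates the early `return True, precision` from the fall-through `return False, precision`. A self-contained toy mirroring that contract (the enum and names below are illustrative stand-ins, not the ExecuTorch definitions):

```python
from enum import Enum, auto


class Precision(Enum):  # stand-in for ConfigPrecisionType
    FP32 = auto()
    STATIC_QUANT = auto()
    DYNAMIC_QUANT = auto()


def overwrite_precision(detected: Precision, enabled: list) -> tuple:
    """Mirror of the diff's logic: demote a quantized GEMM to fp32
    when fp32 is the only enabled precision type."""
    if detected not in enabled:
        if enabled == [Precision.FP32] and detected in (
            Precision.STATIC_QUANT,
            Precision.DYNAMIC_QUANT,
        ):
            return True, Precision.FP32
    return False, detected


print(overwrite_precision(Precision.STATIC_QUANT, [Precision.FP32]))
# (True, <Precision.FP32: 1>)
```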
@@ -226,8 +227,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
-            # do not partition the weight node
+            # do not partition the bias node
             return (True, gemm_deps)
 
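The rename makes the intent explicit: when the flag is set for an f32 linear, the weights are treated as non-static (they may change between inferences), so the weight and bias nodes are deliberately left out of the partition rather than baked in as static data. A hedged usage sketch; the kwarg plumbing through `XnnpackPartitioner` is assumed to match the old `force_fp32_dynamic_linear` flag and is not shown in this diff:

```python
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

# Assumption: the renamed flag is forwarded to each gemm config via
# **kwargs, just as force_fp32_dynamic_linear was before this change.
partitioner = XnnpackPartitioner(force_non_static_weights_for_f32_linear=True)
```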
@@ -305,8 +309,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])
 
@@ -412,9 +419,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])
 
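With the `TODO(maxren, T210537195)` marker dropped, the same three-line guard now appears verbatim in several `_get_weight_deps` overrides (see also the final hunk below). One possible follow-up, sketched here as a toy with a hypothetical helper name (`_force_non_static_weights` is not in this diff), would be to hoist the guard into the shared config base:

```python
class _ConfigSketch:
    """Toy stand-in for the shared gemm config base class."""

    def __init__(self, force_non_static_weights_for_f32_linear: bool):
        self.force_non_static_weights_for_f32_linear = (
            force_non_static_weights_for_f32_linear
        )

    def _force_non_static_weights(self, precision: str) -> bool:
        # Single shared guard replacing the repeated three-line check.
        return (
            precision == "FP32"
            and self.force_non_static_weights_for_f32_linear
        )


print(_ConfigSketch(True)._force_non_static_weights("FP32"))  # True
```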
@@ -501,11 +510,11 @@ def find_partition_args(input_node):
             node.args = old_args
             node.users = old_users
 
-        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # When using force_non_static_weights_for_f32_linear, we want get_deps to overwrite the source partition nodes.
         # Else we want to be greedy.
         ret_deps = (
             list(set(deps) & set(src_partition.nodes))
-            if self.force_fp32_dynamic_linear
+            if self.force_non_static_weights_for_f32_linear
             else list(set(deps) | set(src_partition.nodes))
         )
 
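To make the intersection-versus-union choice above concrete, a toy example with strings standing in for `torch.fx` nodes (all names illustrative):

```python
deps = {"weight", "bias"}                  # nodes get_deps asked for
src_nodes = {"input", "weight", "output"}  # nodes in src_partition

# Flag on: get_deps overrides the source partition -> intersection.
print(sorted(deps & src_nodes))  # ['weight']
# Flag off: be greedy and keep both sets -> union.
print(sorted(deps | src_nodes))  # ['bias', 'input', 'output', 'weight']
```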
@@ -531,8 +540,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])
 