@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision

     def get_deps(
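Note (not part of the diff): to make the fallback logic in _overwrite_precision easy to check in isolation, here is a minimal runnable sketch; the enum values and the ExampleConfig class are stand-ins invented for this note, not the real partitioner config.

from enum import Enum

class ConfigPrecisionType(Enum):
    FP32 = 0
    STATIC_QUANT = 1
    DYNAMIC_QUANT = 2

class ExampleConfig:
    def __init__(self, enabled_precision_types):
        self.enabled_precision_types = enabled_precision_types

    def overwrite_precision(self, detected):
        # Mirrors the control flow above: fall back to FP32 only when FP32
        # is the sole enabled precision and the detected precision is a
        # quantized one.
        if detected not in self.enabled_precision_types:
            if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
                if detected in (
                    ConfigPrecisionType.STATIC_QUANT,
                    ConfigPrecisionType.DYNAMIC_QUANT,
                ):
                    return True, ConfigPrecisionType.FP32
        return False, detected

# A quantized gemm under an FP32-only config is overwritten to FP32:
cfg = ExampleConfig([ConfigPrecisionType.FP32])
assert cfg.overwrite_precision(ConfigPrecisionType.DYNAMIC_QUANT) == (
    True,
    ConfigPrecisionType.FP32,
)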
@@ -220,8 +221,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

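The renamed flag gates the same early return in every _get_weight_deps/_get_bias_deps override below, so the pattern only needs to be read once. A hedged sketch of what the returned tuple means, reusing the ConfigPrecisionType stub from the note above; the helper name and weight_nodes parameter are invented for illustration, not real API:

from typing import List, Tuple

def deps_sketch(
    precision: "ConfigPrecisionType",
    force_non_static_weights_for_f32_linear: bool,
    weight_nodes: List[str],
) -> Tuple[bool, List[str]]:
    # (valid, deps): returning valid=True with an empty deps list keeps the
    # gemm itself partitionable while leaving its weight/bias constants out
    # of the delegated subgraph.
    if (
        precision == ConfigPrecisionType.FP32
        and force_non_static_weights_for_f32_linear
    ):
        return (True, [])
    return (True, weight_nodes)

assert deps_sketch(ConfigPrecisionType.FP32, True, ["w"]) == (True, [])
assert deps_sketch(ConfigPrecisionType.FP32, False, ["w"]) == (True, ["w"])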
@@ -299,8 +303,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])

@@ -406,9 +413,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

@@ -495,11 +504,11 @@ def find_partition_args(input_node):
         node.args = old_args
         node.users = old_users

-        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # When using force_non_static_weights_for_f32_linear, we want get_deps to overwrite the source partition nodes.
         # Else we want to be greedy.
         ret_deps = (
             list(set(deps) & set(src_partition.nodes))
-            if self.force_fp32_dynamic_linear
+            if self.force_non_static_weights_for_f32_linear
             else list(set(deps) | set(src_partition.nodes))
         )

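The intersection-vs-union switch is easiest to check on concrete sets; the toy node names below stand in for the real torch.fx nodes:

deps = {"input", "bias"}  # what get_deps returned for the gemm
src_partition_nodes = {"input", "weight", "bias", "linear"}

# Flag on: intersect, so get_deps overrides the source partition and a
# node it deliberately omitted (here "weight") stays out.
assert deps & src_partition_nodes == {"input", "bias"}

# Flag off: greedy union keeps every source-partition node, weight included.
assert deps | src_partition_nodes == {"input", "bias", "weight", "linear"}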
@@ -525,8 +534,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])
