@@ -97,9 +97,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -108,6 +108,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision

     def get_deps(
@@ -210,8 +211,11 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -287,8 +291,11 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, [])

@@ -394,9 +401,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
             # do not partition the weight node
             return (True, [])

@@ -482,11 +491,11 @@ def find_partition_args(input_node):
         node.args = old_args
         node.users = old_users

-        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # When using force_non_static_weights_for_f32_linear, we want to get_deps to overwrite the source partition nodes.
         # Else we want to be greedy.
         ret_deps = (
             list(set(deps) & set(src_partition.nodes))
-            if self.force_fp32_dynamic_linear
+            if self.force_non_static_weights_for_f32_linear
             else list(set(deps) | set(src_partition.nodes))
         )

@@ -512,8 +521,11 @@ def __init__(self, **kwargs):
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+        if (
+            precision == ConfigPrecisionType.FP32
+            and self.force_non_static_weights_for_f32_linear
+        ):
+            # if force_non_static_weights_for_f32_linear is on and we detected this as fp32, then we
            # do not partition the weight node
             return (True, [])

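For reference, a minimal self-contained sketch of the gate these hunks converge on (GemmConfigSketch and its argument names are simplified stand-ins, not the actual module): when the detected precision is FP32 and force_non_static_weights_for_f32_linear is set, the weight node is left out of the gemm's dependencies so the linear op keeps a non-static (runtime) weight.

from typing import Any, List, Tuple


class GemmConfigSketch:
    """Simplified stand-in for the partitioner config touched above."""

    def __init__(self, force_non_static_weights_for_f32_linear: bool = False) -> None:
        self.force_non_static_weights_for_f32_linear = force_non_static_weights_for_f32_linear

    def _get_weight_deps(self, weight_node: Any, precision: str) -> Tuple[bool, List[Any]]:
        if precision == "FP32" and self.force_non_static_weights_for_f32_linear:
            # do not partition the weight node: succeed with no extra dependencies,
            # leaving the weight as a runtime (non-static) input to the linear op
            return (True, [])
        # otherwise the weight would be collected as a dependency of the gemm
        return (True, [weight_node])

With the flag enabled, GemmConfigSketch(force_non_static_weights_for_f32_linear=True)._get_weight_deps(w, "FP32") returns (True, []), mirroring the early return added in each _get_weight_deps/_get_bias_deps hunk above.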