@@ -96,9 +96,9 @@ def _detect_precision(self, node: torch.fx.Node) -> ConfigPrecisionType:
     def _overwrite_precision(self, node: torch.fx.Node):
         precision = self._detect_precision(node)
         if precision not in self.enabled_precision_types:
-            # detected precision is not enabled, lets try to partition it as fp32
+            # detected precision is not enabled, try to partition it as fp32
             if self.enabled_precision_types == [ConfigPrecisionType.FP32]:
-                # if only fp32 is enabled, then we can still partition fp32 gemms
+                # when only fp32 is enabled, then we can still partition fp32 gemms
                 # even with in a quantized graph
                 if precision in [
                     ConfigPrecisionType.STATIC_QUANT,
@@ -107,6 +107,7 @@ def _overwrite_precision(self, node: torch.fx.Node):
                     precision = ConfigPrecisionType.FP32
                     logging.info(f"Overwriting precision, partitioning {node} as FP32")
                     return True, precision
+
         return False, precision

     def get_deps(
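To make the two-value return contract above concrete, here is a minimal, self-contained sketch: the first element reports whether the detected precision was overridden, the second is the precision to partition with. ToyGemmConfig and its attribute names are hypothetical stand-ins, not the real XNNPACK config classes.

from enum import Enum
from typing import List, Tuple


class ConfigPrecisionType(Enum):
    FP32 = 0
    STATIC_QUANT = 1
    DYNAMIC_QUANT = 2


class ToyGemmConfig:
    """Hypothetical stand-in that models only the precision-override rule."""

    def __init__(self, enabled_precision_types: List[ConfigPrecisionType]):
        self.enabled_precision_types = enabled_precision_types

    def overwrite_precision(
        self, detected: ConfigPrecisionType
    ) -> Tuple[bool, ConfigPrecisionType]:
        if detected not in self.enabled_precision_types:
            # A quantized gemm can still be claimed as fp32 when fp32 is the
            # only precision the partitioner was configured with.
            if self.enabled_precision_types == [ConfigPrecisionType.FP32] and detected in (
                ConfigPrecisionType.STATIC_QUANT,
                ConfigPrecisionType.DYNAMIC_QUANT,
            ):
                return True, ConfigPrecisionType.FP32
        return False, detected


cfg = ToyGemmConfig([ConfigPrecisionType.FP32])
print(cfg.overwrite_precision(ConfigPrecisionType.DYNAMIC_QUANT))  # overridden to FP32
print(cfg.overwrite_precision(ConfigPrecisionType.FP32))           # left untouched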
@@ -124,7 +125,6 @@ def get_deps(
             # detected precision but it is either disabled or not supported
             why(node, f"Unsupported precision type {precision}")
             return (False, [])
-        _, precision = self._overwrite_precision(node)
         valid_bias, bias_deps = self._get_bias_deps(node, ep, precision)
         valid_weight, weight_deps = self._get_weight_deps(node, ep, precision)
         valid_act, act_deps = self._get_act_deps(node, ep, precision)
@@ -139,6 +139,11 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
+            # if force_non_static_weights_for_f32_linear is enabled, then we
+            # do not partition the weight node
+            return (True, gemm_deps)
+
         if precision == ConfigPrecisionType.FP32:
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
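A small sketch of the early return added above, under the assumption that the flag simply suppresses weight dependencies for fp32 gemms so the weight is free to be non-static. weight_deps_sketch and its parameters are hypothetical names, not the real _get_weight_deps signature.

from typing import Any, List, Tuple


def weight_deps_sketch(
    precision_is_fp32: bool,
    force_non_static_weights_for_f32_linear: bool,
    weight_node: Any,
) -> Tuple[bool, List[Any]]:
    # Paraphrase of the early return: when the flag is set for an fp32 gemm,
    # the weight is not required to be a static parameter, so no weight
    # dependency is collected for the partition.
    if precision_is_fp32 and force_non_static_weights_for_f32_linear:
        return True, []
    # Otherwise the weight node itself becomes a dependency (the static-param
    # checks of the real implementation are elided here).
    return True, [weight_node]


print(weight_deps_sketch(True, True, "fake_weight_node"))   # (True, [])
print(weight_deps_sketch(True, False, "fake_weight_node"))  # (True, ['fake_weight_node'])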
@@ -220,8 +225,8 @@ def _get_bias_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force force_fp32_dynamic_linear is enabled, then we
+        if precision == ConfigPrecisionType.FP32 and self.force_non_static_weights_for_f32_linear:
+            # if force_non_static_weights_for_f32_linear is enabled, then we
             # do not partition the weight node
             return (True, gemm_deps)

@@ -299,11 +304,6 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
     def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is enabled, then we
-            # do not partition the weight node
-            return (True, [])
-
         # Since we are in Linear, we may assume that the weights are indeed static.
         overwritten_linear_precision, new_precision = self._overwrite_precision(node)
         if new_precision == ConfigPrecisionType.FP32 and overwritten_linear_precision:
@@ -403,17 +403,6 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]

-    def _get_weight_deps(
-        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
-    ) -> Tuple[bool, List[torch.fx.Node]]:
-        # TODO(maxren, T210537195):
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
-            # do not partition the weight node
-            return (True, [])
-
-        return super()._get_weight_deps(node, ep, precision)
-
     def get_deps(
         self,
         node: torch.fx.Node,
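The overrides deleted above duplicated the flag check that now lives in the base _get_weight_deps, so the subclasses simply inherit it. A rough sketch of the resulting shape, with hypothetical BaseGemmSketch/LinearSketch classes standing in for the real config hierarchy:

from typing import Any, List, Tuple


class BaseGemmSketch:
    """Hypothetical base: the flag check now lives in exactly one place."""

    def __init__(self, force_non_static_weights_for_f32_linear: bool = False):
        self.force_non_static_weights_for_f32_linear = force_non_static_weights_for_f32_linear

    def get_weight_deps(self, is_fp32: bool, weight_node: Any) -> Tuple[bool, List[Any]]:
        if is_fp32 and self.force_non_static_weights_for_f32_linear:
            return True, []
        return True, [weight_node]


class LinearSketch(BaseGemmSketch):
    # No per-subclass override any more; the behavior removed from the
    # Linear/addmm configs in this diff is inherited from the base class.
    pass


print(LinearSketch(force_non_static_weights_for_f32_linear=True).get_weight_deps(True, "w"))
# (True, [])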
@@ -495,11 +484,11 @@ def find_partition_args(input_node):
         node.args = old_args
         node.users = old_users

-        # When using force_fp32_dynamic_linear, we want to get_deps to overwrite the source partition nodes.
+        # When using force_non_static_weights_for_f32_linear, we want get_deps to overwrite the source partition nodes.
         # Else we want to be greedy.
         ret_deps = (
             list(set(deps) & set(src_partition.nodes))
-            if self.force_fp32_dynamic_linear
+            if self.force_non_static_weights_for_f32_linear
             else list(set(deps) | set(src_partition.nodes))
         )

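A toy illustration of the intersection-versus-union choice above: with the flag set, only nodes that survived get_deps are kept; otherwise everything in the matched source partition is claimed as well. The node names are made up for the example.

deps = {"bias", "input_activation"}  # the weight was dropped by the flag
src_partition_nodes = {"weight", "bias", "linear_op", "view_copy"}

force_non_static_weights_for_f32_linear = True

ret_deps = (
    # Let get_deps overwrite the source partition: keep only the nodes that
    # survived the dependency checks.
    sorted(deps & src_partition_nodes)
    if force_non_static_weights_for_f32_linear
    # Greedy default: claim everything from both sets.
    else sorted(deps | src_partition_nodes)
)
print(ret_deps)  # ['bias'] when the flag is set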
@@ -522,16 +511,6 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0

-    def _get_weight_deps(
-        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
-    ) -> Tuple[bool, List[torch.fx.Node]]:
-        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
-            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
-            # do not partition the weight node
-            return (True, [])
-
-        return super()._get_weight_deps(node, ep, precision)
-
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,