@@ -122,6 +122,7 @@ def _get_weight_deps(
122122 self , node : torch .fx .Node , ep : ExportedProgram , precision : ConfigPrecisionType
123123 ) -> Tuple [bool , List [torch .fx .Node ]]:
124124 gemm_deps = []
125+ breakpoint ()
125126 if precision == ConfigPrecisionType .FP32 :
126127 # First find the weight
127128 weight_node = get_input_node (node , self .weight_idx )
@@ -272,6 +273,17 @@ def _get_weight_deps(
272273
273274 return super ()._get_weight_deps (node , ep , precision )
274275
276+ def _get_bias_deps (
277+ self , node : torch .fx .Node , ep : ExportedProgram , precision : ConfigPrecisionType
278+ ) -> Tuple [bool , List [torch .fx .Node ]]:
279+ if precision == ConfigPrecisionType .FP32 and self .force_fp32_dynamic_linear :
280+ # if force fp32_dynamic_linear is on and we detected this as fp32, then we
281+ # do not partition the weight node
282+ breakpoint ()
283+ return (True , [])
284+
285+ return super ()._get_bias_deps (node , ep , precision )
286+
275287 def supported_precision_types (self ):
276288 return [
277289 ConfigPrecisionType .DYNAMIC_QUANT ,
@@ -366,6 +378,27 @@ def get_deps(
366378
367379 return super ().get_deps (node , ep )
368380
381+ def _get_weight_deps (
382+ self , node : torch .fx .Node , ep : ExportedProgram , precision : ConfigPrecisionType
383+ ) -> Tuple [bool , List [torch .fx .Node ]]:
384+ if precision == ConfigPrecisionType .FP32 and self .force_fp32_dynamic_linear :
385+ # if force fp32_dynamic_linear is on and we detected this as fp32, then we
386+ # do not partition the weight node
387+ return (True , [])
388+
389+ return super ()._get_weight_deps (node , ep , precision )
390+
391+ def _get_bias_deps (
392+ self , node : torch .fx .Node , ep : ExportedProgram , precision : ConfigPrecisionType
393+ ) -> Tuple [bool , List [torch .fx .Node ]]:
394+ if precision == ConfigPrecisionType .FP32 and self .force_fp32_dynamic_linear :
395+ # if force fp32_dynamic_linear is on and we detected this as fp32, then we
396+ # do not partition the weight node
397+ breakpoint ()
398+ return (True , [])
399+
400+ return super ()._get_bias_deps (node , ep , precision )
401+
369402 def get_deps_from_src_partition (
370403 self , node : torch .fx .Node , ep : ExportedProgram , src_partition : SourcePartition
371404 ):
0 commit comments