
Commit e779ac6

JacobSzwejbka authored and facebook-github-bot committed
Allow addmm and mm to call dynamic fp32 kernels
Summary: Allow addmm and mm to call the dynamic weight kernels. Differential Revision: D66898281
1 parent ec56da8 commit e779ac6

1 file changed: +20 −0

backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 20 additions & 0 deletions
@@ -337,6 +337,16 @@ def __init__(self, **kwargs):
         self.src_partitions = None
         self.linear_modules = [torch.nn.functional.linear, torch.nn.Linear]
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def get_deps(
         self,
         node: torch.fx.Node,
@@ -436,6 +446,16 @@ def __init__(self, **kwargs):
         self.weight_idx = 1
         self.act_idx = 0
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.FP32,
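
For context, a minimal usage sketch (not part of the commit) of how this change would be exercised: a module whose forward uses torch.addmm / torch.mm is exported and lowered to XNNPACK with force_fp32_dynamic_linear set, so the fp32 weight is left out of the partition and the dynamic weight kernels handle it at runtime. The force_fp32_dynamic_linear flag name comes from the diff above; the surrounding export and to_backend calls are the usual ExecuTorch flow and are assumptions here, not something this commit adds.

# Hypothetical sketch: export a module that calls torch.addmm and lower it to
# XNNPACK with force_fp32_dynamic_linear=True (flag name taken from the diff;
# the rest of the flow is the standard ExecuTorch export path, assumed here).
import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge
from torch.export import export


class AddmmModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(64, 32))
        self.bias = torch.nn.Parameter(torch.randn(64))

    def forward(self, x):
        # addmm / mm rather than nn.Linear -- the ops this commit enables
        return torch.addmm(self.bias, x, self.weight.t())


edge = to_edge(export(AddmmModule(), (torch.randn(8, 32),)))
edge = edge.to_backend(XnnpackPartitioner(force_fp32_dynamic_linear=True))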
