Commit b197abc

xnnpack io

1 parent 63017e4 commit b197abc
3 files changed: +37 -1 lines changed


backends/xnnpack/partition/config/gemm_configs.py

Lines changed: 33 additions & 0 deletions
@@ -122,6 +122,7 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        breakpoint()
         if precision == ConfigPrecisionType.FP32:
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
@@ -272,6 +273,17 @@ def _get_weight_deps(
 
         return super()._get_weight_deps(node, ep, precision)
 
+    def _get_bias_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the bias node
+            breakpoint()
+            return (True, [])
+
+        return super()._get_bias_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.DYNAMIC_QUANT,
@@ -366,6 +378,27 @@ def get_deps(
 
         return super().get_deps(node, ep)
 
+    def _get_weight_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            return (True, [])
+
+        return super()._get_weight_deps(node, ep, precision)
+
+    def _get_bias_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force_fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the bias node
+            breakpoint()
+            return (True, [])
+
+        return super()._get_bias_deps(node, ep, precision)
+
     def get_deps_from_src_partition(
         self, node: torch.fx.Node, ep: ExportedProgram, src_partition: SourcePartition
     ):
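For context on the return convention these overrides rely on: each helper returns a (bool, List[torch.fx.Node]) pair, where the bool says whether the GEMM node itself is still valid to lower and the list names the extra nodes the partition must pull in with it. Per the comments above, returning (True, []) keeps the op but leaves the fp32 weight/bias outside the partition. A minimal, self-contained sketch of that gating pattern; SimpleGemmConfig and Precision are hypothetical stand-ins for illustration, not the real ExecuTorch classes:

    from enum import Enum, auto
    from typing import List, Tuple


    class Precision(Enum):
        FP32 = auto()
        DYNAMIC_QUANT = auto()


    class SimpleGemmConfig:
        def __init__(self, force_fp32_dynamic_linear: bool = False) -> None:
            self.force_fp32_dynamic_linear = force_fp32_dynamic_linear

        def _get_weight_deps(
            self, weight: str, precision: Precision
        ) -> Tuple[bool, List[str]]:
            # (valid, deps): valid=True means the GEMM itself can still be
            # lowered; an empty deps list means the weight stays outside the
            # partition instead of being pulled in with it.
            if precision == Precision.FP32 and self.force_fp32_dynamic_linear:
                return (True, [])
            return (True, [weight])


    config = SimpleGemmConfig(force_fp32_dynamic_linear=True)
    print(config._get_weight_deps("linear.weight", Precision.FP32))           # (True, [])
    print(config._get_weight_deps("linear.weight", Precision.DYNAMIC_QUANT))  # (True, ['linear.weight'])

Running the sketch prints (True, []) for the forced-fp32 case and (True, ['linear.weight']) otherwise, mirroring how the _get_weight_deps/_get_bias_deps overrides decide what travels into the partition.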

backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", True)
 
     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
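Note that flipping the default from False to True opts every caller into the forced-fp32 path, including callers that never pass the kwarg. If the intent is only to exercise that path for this experiment, a more targeted route is to set it at the call site. A hedged sketch, assuming XnnpackPartitioner forwards its keyword arguments down to each config's __init__, which is how the kwargs.get above is reached:

    # Hypothetical call-site opt-in; assumes **kwargs given to the partitioner
    # reach the XNNPartitionerConfig constructors, rather than changing the
    # library-wide default in xnnpack_config.py.
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

    partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=True)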

examples/xnnpack/aot_compiler.py

Lines changed: 3 additions & 0 deletions
@@ -93,6 +93,7 @@
         # TODO(T165162973): This pass shall eventually be folded into quantizer
         model = quantize(model, example_inputs)
 
+    breakpoint()
     edge = to_edge_transform_and_lower(
         ep,
         partitioner=[XnnpackPartitioner()],
@@ -110,6 +111,8 @@
         config=ExecutorchBackendConfig(extract_delegate_segments=False)
     )
 
+    breakpoint()
+
     if args.etrecord is not None:
         generate_etrecord(args.etrecord, edge_copy, exec_prog)
         logging.info(f"Saved ETRecord to {args.etrecord}")
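The breakpoint() calls added here (and in gemm_configs.py) drop the run into pdb, which will hang any non-interactive invocation such as CI or a scripted export. One way to keep them harmless is Python's standard PYTHONBREAKPOINT hook (PEP 553): setting it to 0 turns every breakpoint() call into a no-op:

    PYTHONBREAKPOINT=0 python examples/xnnpack/aot_compiler.py <args>

The same guard can be applied from inside the script before the first breakpoint() fires; a minimal sketch:

    import os
    import sys

    # If stdin is not an interactive terminal, disable the built-in
    # breakpoint() hook (PEP 553) unless the caller already set it.
    if not sys.stdin.isatty():
        os.environ.setdefault("PYTHONBREAKPOINT", "0")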
