33 changes: 33 additions & 0 deletions backends/xnnpack/partition/config/gemm_configs.py
@@ -122,6 +122,7 @@ def _get_weight_deps(
         self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
     ) -> Tuple[bool, List[torch.fx.Node]]:
         gemm_deps = []
+        breakpoint()
         if precision == ConfigPrecisionType.FP32:
             # First find the weight
             weight_node = get_input_node(node, self.weight_idx)
@@ -272,6 +273,17 @@ def _get_weight_deps(
 
         return super()._get_weight_deps(node, ep, precision)
 
+    def _get_bias_deps(
+        self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
+    ) -> Tuple[bool, List[torch.fx.Node]]:
+        if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
+            # if force fp32_dynamic_linear is on and we detected this as fp32, then we
+            # do not partition the weight node
+            breakpoint()
+            return (True, [])
+
+        return super()._get_bias_deps(node, ep, precision)
+
     def supported_precision_types(self):
         return [
             ConfigPrecisionType.DYNAMIC_QUANT,
Expand Down Expand Up @@ -366,6 +378,27 @@ def get_deps(

return super().get_deps(node, ep)

def _get_weight_deps(
self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
) -> Tuple[bool, List[torch.fx.Node]]:
if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
# if force fp32_dynamic_linear is on and we detected this as fp32, then we
# do not partition the weight node
return (True, [])

return super()._get_weight_deps(node, ep, precision)

def _get_bias_deps(
self, node: torch.fx.Node, ep: ExportedProgram, precision: ConfigPrecisionType
) -> Tuple[bool, List[torch.fx.Node]]:
if precision == ConfigPrecisionType.FP32 and self.force_fp32_dynamic_linear:
# if force fp32_dynamic_linear is on and we detected this as fp32, then we
# do not partition the weight node
breakpoint()
return (True, [])

return super()._get_bias_deps(node, ep, precision)

def get_deps_from_src_partition(
self, node: torch.fx.Node, ep: ExportedProgram, src_partition: SourcePartition
):
Expand Down
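
With the overrides above, an FP32 linear whose config has force_fp32_dynamic_linear set returns (True, []) for weight and bias deps, so those tensors stay outside the delegated partition. A minimal usage sketch of the flag, assuming XnnpackPartitioner forwards extra keyword arguments down to these configs (the model, shapes, and variable names below are illustrative, not from the PR):

# Minimal sketch, not from the PR: exercising force_fp32_dynamic_linear on a
# small FP32 linear model.
import torch
from torch.export import export

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


ep = export(TinyLinear(), (torch.randn(1, 8),))

# With the flag on, _get_weight_deps/_get_bias_deps return (True, []) for FP32
# linears, so the weight and bias tensors are not pulled into the delegate.
edge = to_edge_transform_and_lower(
    ep,
    partitioner=[XnnpackPartitioner(force_fp32_dynamic_linear=True)],
)
exec_prog = edge.to_executorch()
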
2 changes: 1 addition & 1 deletion backends/xnnpack/partition/config/xnnpack_config.py
@@ -41,7 +41,7 @@ def __init__(self, **kwargs):
         super().__init__()
         self.enabled_precision_types = self.supported_precision_types()
         # Flag used in GEMMConfig()
-        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", False)
+        self.force_fp32_dynamic_linear = kwargs.get("force_fp32_dynamic_linear", True)
 
     def get_partition(
         self, node: torch.fx.Node, ep: ExportedProgram
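
The hunk above flips the default, so every config now forces FP32 dynamic linear unless a caller turns it off. A small sketch of the opt-out, assuming the partitioner forwards keyword arguments to XNNPartitionerConfig (hypothetical usage, not part of the diff):

# Hypothetical usage sketch: with the new default, constructing the partitioner
# with no arguments behaves as if force_fp32_dynamic_linear=True were passed.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

default_partitioner = XnnpackPartitioner()  # flag now defaults to True

# Callers that still want FP32 linear weights/biases lowered into the delegate
# would have to opt out explicitly.
opt_out_partitioner = XnnpackPartitioner(force_fp32_dynamic_linear=False)
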
3 changes: 3 additions & 0 deletions examples/xnnpack/aot_compiler.py
@@ -93,6 +93,7 @@
         # TODO(T165162973): This pass shall eventually be folded into quantizer
         model = quantize(model, example_inputs)
 
+    breakpoint()
     edge = to_edge_transform_and_lower(
         ep,
         partitioner=[XnnpackPartitioner()],
@@ -110,6 +111,8 @@
         config=ExecutorchBackendConfig(extract_delegate_segments=False)
     )
 
+    breakpoint()
+
     if args.etrecord is not None:
         generate_etrecord(args.etrecord, edge_copy, exec_prog)
         logging.info(f"Saved ETRecord to {args.etrecord}")
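
The two breakpoint() calls above are interactive stops around lowering and to_executorch. A non-interactive way to get similar visibility is to print the lowered graph; a rough sketch, assuming edge is the EdgeProgramManager returned by to_edge_transform_and_lower in this script:

# Rough sketch (assumes `edge` from to_edge_transform_and_lower above): print
# the partitioned graph instead of pausing in a debugger, so the delegate
# (executorch_call_delegate) nodes are visible in the log.
graph_module = edge.exported_program().graph_module
graph_module.print_readable()
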