Skip to content

Commit 56f68a8

Browse files
authored
Remove _skip_type_promotion config
Differential Revision: D77619493 Pull Request resolved: #12149
1 parent a909b83 commit 56f68a8

File tree

6 files changed

+2
-16
lines changed

6 files changed

+2
-16
lines changed

examples/apple/coreml/llama/export.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ def main() -> None:
206206
],
207207
compile_config=EdgeCompileConfig(
208208
_check_ir_validity=False,
209-
_skip_type_promotion=(float_dtype == torch.float16),
210209
_skip_dim_order=True,
211210
),
212211
)

exir/capture/_config.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ class EdgeCompileConfig:
4343
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
4444
default_factory=list
4545
)
46-
_skip_type_promotion: bool = False
4746
# TODO(gasoonjia): remove this
4847
_skip_dim_order: bool = False
4948

exir/program/_program.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -652,9 +652,7 @@ def _get_aten_to_edge_passes(config: EdgeCompileConfig):
652652
# well with node.meta, meaning after some passes permuting operators, we may lose some information in node.meta.
653653
# It might be regenerated in SpecPropPass so it may not be visiable. However debug handle will be lost.
654654

655-
pre_op_replace_passes = base_pre_op_replace_passes + (
656-
[] if config._skip_type_promotion else [RemoveMixedTypeOperators()]
657-
)
655+
pre_op_replace_passes = base_pre_op_replace_passes + [RemoveMixedTypeOperators()]
658656

659657
post_op_replace_passes = base_post_op_replace_passes
660658

extension/llm/export/builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def _get_dynamic_shape(self) -> Any:
214214
def _get_edge_config(self) -> EdgeCompileConfig:
215215
edge_config = EdgeCompileConfig(
216216
_check_ir_validity=False,
217-
_skip_type_promotion=bool(self.dtype == DType.fp16),
218217
_skip_dim_order=True,
219218
)
220219
return edge_config

test/end2end/exported_module.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def export(
6767
ignore_to_out_var_failure: bool = False,
6868
dynamic_memory_planning_mode: DynamicMemoryPlanningMode = DynamicMemoryPlanningMode.UPPER_BOUND,
6969
capture_config=None,
70-
skip_type_promotion: bool = False,
7170
export_joint_graph: bool = False,
7271
external_constants: bool = False,
7372
export_state_names: bool = False,
@@ -194,7 +193,7 @@ def __init__(self, method):
194193
exec_prog = to_edge(
195194
exported_methods,
196195
compile_config=exir.EdgeCompileConfig(
197-
_check_ir_validity=False, _skip_type_promotion=skip_type_promotion
196+
_check_ir_validity=False,
198197
),
199198
).to_executorch(
200199
ExecutorchBackendConfig(

test/models/export_program.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,6 @@ def get_random_inputs(self):
269269

270270
def export_module_to_program(
271271
module_class: Type[nn.Module],
272-
skip_type_promotion: bool,
273272
external_constants: bool = False,
274273
) -> ExecutorchProgramManager:
275274
"""Exports the module and returns the serialized program data."""
@@ -293,7 +292,6 @@ def export_module_to_program(
293292
module = ExportedModule.export(
294293
module_class,
295294
methods,
296-
skip_type_promotion=skip_type_promotion,
297295
export_joint_graph=export_joint,
298296
external_constants=external_constants,
299297
export_state_names=export_state_names,
@@ -342,17 +340,11 @@ def main() -> None:
342340
# Export and write to the output files.
343341
os.makedirs(args.outdir, exist_ok=True)
344342
for module_name, module_class in module_names_to_classes.items():
345-
skip_type_promotion = False
346-
if module_name == "ModuleAddHalf":
347-
# Skip type promotion to keep the model in fp16.
348-
# Type promotion will convert to fp32.
349-
skip_type_promotion = True
350343
if args.external_constants:
351344
module_name = f"{module_name}Program"
352345
outfile = os.path.join(args.outdir, f"{module_name}.pte")
353346
prog = export_module_to_program(
354347
module_class,
355-
skip_type_promotion=skip_type_promotion,
356348
external_constants=args.external_constants,
357349
)
358350
with open(outfile, "wb") as fp:

0 commit comments

Comments
 (0)