66
77from typing import Callable , Optional , Any
88from os import PathLike
9+ from pathlib import Path
910import functools
1011
1112import torch
1617from torch .utils ._pytree import tree_structure , tree_unflatten , tree_flatten
1718from amdsharktank .types .tensors import ShardedTensor
1819from amdsharktank .types .theta import mark_export_external_theta
19- from amdsharktank .layers import BaseLayer , ThetaLayer
20+
21+ from typing import TYPE_CHECKING
22+
23+ if TYPE_CHECKING :
24+ from amdsharktank .layers import BaseLayer , ThetaLayer
25+
26+ # from amdsharktank.layers import BaseLayer, ThetaLayer
2027
2128
2229def flatten_signature (
@@ -180,7 +187,7 @@ def flat_fn(*args, **kwargs):
180187
181188
182189def export_model_mlir (
183- model : BaseLayer ,
190+ model : " BaseLayer" ,
184191 output_path : PathLike ,
185192 * ,
186193 function_batch_sizes_map : Optional [dict [Optional [str ], list [int ]]] = None ,
@@ -202,7 +209,7 @@ def export_model_mlir(
202209
203210 assert not (function_batch_sizes_map is not None and batch_sizes is not None )
204211
205- if isinstance (model , ThetaLayer ):
212+ if isinstance (model , " ThetaLayer" ):
206213 mark_export_external_theta (model .theta )
207214
208215 if batch_sizes is not None :
@@ -317,7 +324,7 @@ def export_torch_module_to_mlir_file(
317324 def _ (module , * fn_args ):
318325 return module .forward (* fn_args )
319326
320- export_output = export (fxb , import_symbolic_shape_expressions = True )
327+ export_output = aot . export (fxb , import_symbolic_shape_expressions = True )
321328 export_output .save_mlir (mlir_path )
322329
323330 return export_output
0 commit comments