Skip to content

Commit 0704751

Browse files
chunnienccopybara-github
authored andcommitted
Add internal kwarg to set intermediate saved model dir
PiperOrigin-RevId: 706748161
1 parent 8e77ee1 commit 0704751

File tree

5 files changed

+27
-1
lines changed

5 files changed

+27
-1
lines changed

ai_edge_torch/_convert/conversion.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def convert_signatures(
7878
*,
7979
strict_export: Union[Literal["auto"], bool] = True,
8080
quant_config: Optional[qcfg.QuantConfig] = None,
81-
_tfl_converter_flags: Optional[dict[str, Any]],
81+
_tfl_converter_flags: Optional[dict[str, Any]] = None,
82+
_saved_model_dir: Optional[str] = None,
8283
) -> model.TfLiteModel:
8384
"""Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
8485
@@ -93,6 +94,8 @@ def convert_signatures(
9394
quant_config: User-defined quantization method and scheme of the model.
9495
_tfl_converter_flags: A nested dictionary allowing setting flags for the
9596
underlying tflite converter.
97+
_saved_model_dir: Directory for the intermediate saved model. If not
98+
specified, a random temporary directory would be used.
9699
97100
Returns:
98101
The converted `model.TfLiteModel` object.
@@ -140,6 +143,7 @@ def export(*args, **kwargs):
140143
signatures,
141144
quant_config=quant_config,
142145
_tfl_converter_flags=_tfl_converter_flags,
146+
_saved_model_dir=_saved_model_dir,
143147
)
144148

145149
return model.TfLiteModel(tflite_model)

ai_edge_torch/_convert/converter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def convert(
106106
quant_config: Optional[qcfg.QuantConfig] = None,
107107
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
108108
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
109+
_saved_model_dir: Optional[str] = None,
109110
) -> model.TfLiteModel:
110111
"""Finalizes the conversion and produces an edge model.
111112
@@ -139,6 +140,8 @@ def convert(
139140
of this function and so needs to be treated as such. Please do not rely
140141
on this parameter except for local debugging as this can be removed in a
141142
future release.
143+
_saved_model_dir: Directory for the intermediate saved model. If not
144+
specified, a random temporary directory would be used.
142145
143146
Returns:
144147
The converted edge model.
@@ -171,6 +174,7 @@ def convert(
171174
strict_export=strict_export,
172175
quant_config=quant_config,
173176
_tfl_converter_flags=_ai_edge_converter_flags,
177+
_saved_model_dir=_saved_model_dir,
174178
)
175179

176180

@@ -216,6 +220,7 @@ def convert(
216220
quant_config: Optional[qcfg.QuantConfig] = None,
217221
dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
218222
_ai_edge_converter_flags: Optional[dict[str, Any]] = None,
223+
_saved_model_dir: Optional[str] = None,
219224
) -> model.TfLiteModel:
220225
"""Converts a PyTorch model to an edge model with a default signature.
221226
@@ -240,6 +245,8 @@ def convert(
240245
this function and so needs to be treated as such. Please do not rely on
241246
this parameter except for local debugging as this can be removed in a
242247
future release.
248+
_saved_model_dir: Directory for the intermediate saved model. If not
249+
specified, a random temporary directory would be used.
243250
244251
Returns:
245252
The converted edge model.
@@ -259,4 +266,5 @@ def convert(
259266
quant_config=quant_config,
260267
dynamic_shapes=dynamic_shapes,
261268
_ai_edge_converter_flags=_ai_edge_converter_flags,
269+
_saved_model_dir=_saved_model_dir,
262270
)

ai_edge_torch/lowertools/_shim.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,16 @@ def exported_programs_to_tflite(
5050
*,
5151
quant_config: Optional[qcfg.QuantConfig] = None,
5252
_tfl_converter_flags: Optional[dict[str, Any]] = None,
53+
_saved_model_dir: Optional[str] = None
5354
):
5455
"""Converts a list of ExportedProgram to a TFLite model.
5556
5657
Args:
5758
exported_programs: A list of ExportedProgram.
5859
signatures: A list of Signature.
5960
quant_config: A QuantConfig.
61+
_saved_model_dir: Directory for the intermediate saved model. If not
62+
specified, a random temporary directory would be used.
6063
_tfl_converter_flags: A dict of flags for TFLiteConverter.
6164
6265
Returns:
@@ -79,4 +82,5 @@ def exported_programs_to_tflite(
7982
signatures,
8083
quant_config=quant_config,
8184
_tfl_converter_flags=_tfl_converter_flags,
85+
_saved_model_dir=_saved_model_dir,
8286
)

ai_edge_torch/lowertools/odml_torch_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def merged_bundle_to_tfl_model(
138138
*,
139139
quant_config: Optional[qcfg.QuantConfig] = None,
140140
_tfl_converter_flags: dict = {},
141+
_saved_model_dir: Optional[str] = None,
141142
):
142143
tf_state_dict = merged_bundle.bundles[0].state_dict
143144

@@ -173,6 +174,9 @@ def merged_bundle_to_tfl_model(
173174
# We need to temporarily save since TFLite's from_concrete_functions does not
174175
# allow providing names for each of the concrete functions.
175176
with tempfile.TemporaryDirectory() as temp_dir_path:
177+
if _saved_model_dir is not None:
178+
temp_dir_path = _saved_model_dir
179+
176180
tf.saved_model.save(
177181
tf_module,
178182
temp_dir_path,

ai_edge_torch/lowertools/torch_xla_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def merged_bundle_to_tfl_model(
192192
*,
193193
quant_config: Optional[qcfg.QuantConfig] = None,
194194
_tfl_converter_flags: dict = {},
195+
_saved_model_dir: Optional[str] = None,
195196
) -> None:
196197
"""Converts a StableHLOGraphModule to a tflite model.
197198
@@ -200,6 +201,8 @@ def merged_bundle_to_tfl_model(
200201
signatures: List of signatures from which names of the signatures is
201202
extracted.
202203
quant_config: User-defined quantization method and scheme of the model.
204+
_saved_model_dir: Directory for the intermediate saved model. If not
205+
specified, a random temporary directory would be used.
203206
_tfl_converter_flags: A nested dictionary allowing setting flags for the
204207
underlying tflite converter.
205208
"""
@@ -246,6 +249,9 @@ def merged_bundle_to_tfl_model(
246249
# We need to temporarily save since TFLite's from_concrete_functions does not
247250
# allow providing names for each of the concrete functions.
248251
with tempfile.TemporaryDirectory() as temp_dir_path:
252+
if _saved_model_dir is not None:
253+
temp_dir_path = _saved_model_dir
254+
249255
tf.saved_model.save(
250256
tf_module,
251257
temp_dir_path,

0 commit comments

Comments
 (0)