Skip to content

Commit 7407150

Browse files
chunnienccopybara-github
authored andcommitted
odmltorch tf integration cleanup
PiperOrigin-RevId: 711819517
1 parent 3b2afed commit 7407150

File tree

3 files changed

+1
-73
lines changed

3 files changed

+1
-73
lines changed

ai_edge_torch/odml_torch/export.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,6 @@ def __call__(self, *args):
222222
# Lazy importing TF when execution is needed.
223223
return self.tf_function(*args)
224224

225-
def to_flatbuffer(self):
226-
from . import tf_integration
227-
228-
return tf_integration.mlir_to_flatbuffer(self)
229-
230225

231226
# TODO(b/331481564) Make this a ai_edge_torch FX pass.
232227
def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):

ai_edge_torch/odml_torch/test/test_tf_integration.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,29 +44,6 @@ def forward(self, x, y):
4444

4545
self.assertTrue(np.allclose(lowering_output, torch_output))
4646

47-
@googletest.skip("b/353280409")
48-
def test_dynamic_mlir_lowered_call(self):
49-
class AddModel(torch.nn.Module):
50-
51-
def forward(self, x, y):
52-
return x + y + x + y
53-
54-
model = AddModel().eval()
55-
batch = torch.export.Dim("batch")
56-
ep = torch.export.export(
57-
model,
58-
(torch.rand((2, 10)), torch.rand((2, 10))),
59-
dynamic_shapes={"x": {0: batch}, "y": {0: batch}},
60-
)
61-
62-
lowered = odml_torch.export.exported_program_to_mlir(ep)
63-
64-
val_args = (torch.rand((10, 10)), torch.rand((10, 10)))
65-
torch_output = model(*val_args).detach().numpy()
66-
lowering_output = np.array(lowered(*val_args))
67-
68-
self.assertTrue(np.allclose(lowering_output, torch_output))
69-
7047
def test_resnet18(self):
7148
model = torchvision.models.resnet18().eval()
7249
forward_args = lambda: (torch.rand((1, 3, 224, 224)),)

ai_edge_torch/odml_torch/tf_integration.py

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
"""APIs to convert lowered MLIR from PyTorch to TensorFlow and TFLite artifacts."""
15+
"""APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts."""
1616

1717
import re
18-
import tempfile
1918

2019
import tensorflow as tf
2120
import torch
@@ -155,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered):
155154
_wrap_as_tf_func(lowered, tf_state_dict),
156155
input_signature=_make_input_signatures(lowered),
157156
)
158-
159-
160-
def mlir_to_flatbuffer(lowered: export.MlirLowered):
161-
"""Convert the MLIR lowered to a TFLite flatbuffer binary."""
162-
tf_state_dict = _build_tf_state_dict(lowered)
163-
signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
164-
tf_signatures = [_make_input_signatures(lowered)]
165-
tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)]
166-
167-
tf_module = tf.Module()
168-
tf_module.f = []
169-
170-
for tf_sig, func in zip(tf_signatures, tf_functions):
171-
tf_module.f.append(
172-
tf.function(
173-
func,
174-
input_signature=tf_sig,
175-
)
176-
)
177-
178-
tf_module._variables = list(tf_state_dict.values())
179-
180-
tf_concrete_funcs = [
181-
func.get_concrete_function(*tf_sig)
182-
for func, tf_sig in zip(tf_module.f, tf_signatures)
183-
]
184-
185-
# We need to temporarily save since TFLite's from_concrete_functions does not
186-
# allow providing names for each of the concrete functions.
187-
with tempfile.TemporaryDirectory() as temp_dir_path:
188-
tf.saved_model.save(
189-
tf_module,
190-
temp_dir_path,
191-
signatures={
192-
sig_name: tf_concrete_funcs[idx]
193-
for idx, sig_name in enumerate(signature_names)
194-
},
195-
)
196-
197-
converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
198-
tflite_model = converter.convert()
199-
200-
return tflite_model

0 commit comments

Comments
 (0)