Skip to content

Commit 167e20f

Browse files
chunnienc and copybara-github
authored and committed
support bf16 in model signature conversion
PiperOrigin-RevId: 745667254
1 parent 51f7614 commit 167e20f

File tree

8 files changed

+94
-2
lines changed

8 files changed

+94
-2
lines changed

ai_edge_torch/_convert/conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def _run_convert_passes(
4040
fx_passes.OptimizeLayoutTransposesPass(),
4141
fx_passes.CanonicalizePass(),
4242
fx_passes.BuildAtenCompositePass(),
43-
fx_passes.CanonicalizePass(),
4443
fx_passes.RemoveNonUserOutputsPass(),
44+
fx_passes.CastInputsBf16ToF32Pass(),
4545
fx_passes.CanonicalizePass(),
4646
]
4747

ai_edge_torch/_convert/fx_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
1919
from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
20+
from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
2021
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
2122
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
2223
from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Pass to cast all inputs with torch.bfloat16 type to torch.float32."""
16+
17+
18+
from ai_edge_torch import fx_infra
19+
import torch
20+
21+
22+
def cast_f32(x):
  """Returns `x` converted to torch.float32 through `x.to`."""
  target_dtype = torch.float32
  return x.to(target_dtype)
24+
25+
26+
class CastInputsBf16ToF32Pass(fx_infra.ExportedProgramPassBase):
  """This pass casts all inputs with torch.bfloat16 type to torch.float32.

  The placeholder nodes themselves are left untouched. For each bfloat16
  placeholder that has at least one user, a `cast_f32` call is inserted and
  every consumer of the placeholder is rewired to read from that call.
  """

  def call(self, exported_program: torch.export.ExportedProgram):
    """Inserts a float32 cast after each used bfloat16 placeholder.

    Args:
      exported_program: The program whose graph is rewritten in place.

    Returns:
      An `ExportedProgramPassResult` holding the same program object and a
      flag telling whether any cast node was inserted.
    """
    graph = exported_program.graph
    modified = False
    for node in graph.nodes:
      if node.op != "placeholder":
        continue

      # The metadata may lack "val", or hold a value with no `dtype`
      # attribute; read the dtype defensively so such placeholders are
      # skipped instead of raising AttributeError.
      if getattr(node.meta.get("val"), "dtype", None) != torch.bfloat16:
        continue

      # An unused input needs no cast.
      if not node.users:
        continue

      modified = True
      # NOTE(review): this assumes the first entry of `node.users` is the
      # earliest consumer in graph order -- confirm for graphs rewritten by
      # earlier passes.
      first_user = next(iter(node.users))
      with graph.inserting_before(first_user):
        cast_node = graph.call_function(
            cast_f32,
            (node,),
        )
      # replace_all_uses_with also rewrites the cast's own argument (the cast
      # is itself a user of `node`), so point that argument back at the
      # placeholder afterwards.
      node.replace_all_uses_with(cast_node)
      cast_node.replace_input_with(cast_node, node)

    exported_program.graph_module.recompile()
    return fx_infra.ExportedProgramPassResult(exported_program, modified)

ai_edge_torch/_convert/test/test_convert.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,27 @@ def test_convert_resnet18_pt2e_per_channel(self):
553553
self.fail(f"PT2E conversion failed: {err}")
554554
# pylint: enable=broad-except
555555

556+
def test_convert_model_with_bfloat16_inputs(self):
  """Test converting a simple model with torch.bfloat16 input.

  bf16 inputs remain in the converted model signature but are cast to f32
  right after the model inputs.
  """

  class SampleModel(nn.Module):

    def forward(self, x: torch.Tensor):
      return (x + 1) * 1.2

  model = SampleModel().eval()
  args = (torch.randn(10, 10).to(torch.bfloat16),)
  # Any exception raised by the converter is reported as a test failure.
  # pylint: disable=broad-except
  try:
    ai_edge_torch.convert(model, args)
  except Exception as err:
    self.fail(f"Conversion failed with bfloat16 inputs: {err}")
  # pylint: enable=broad-except
576+
556577

557578
if __name__ == "__main__":
558579
googletest.main()

ai_edge_torch/lowertools/odml_torch_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def torch_dtype_to_tf(dtype):
5252
torch.int32: tf.int32,
5353
torch.int16: tf.int16,
5454
torch.bool: tf.bool,
55+
torch.bfloat16: tf.bfloat16,
5556
}.get(dtype)
5657

5758

ai_edge_torch/odml_torch/lowerings/_basic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,22 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
301301
)
302302
out = stablehlo.select(pred, self, src)
303303
return out
304+
305+
306+
# Schema:
307+
# - aten::_to_copy(Tensor self, *, ScalarType? dtype=None,
308+
# Layout? layout=None, Device? device=None, bool? pin_memory=None,
309+
# bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
310+
@lower(torch.ops.aten._to_copy.default)
def _aten_to_copy(
    lctx, x: ir.Value, dtype: torch.dtype | None = None, **kwargs
):
  """Lowers aten._to_copy to a StableHLO element-type conversion.

  Args:
    lctx: The lowering context (unused here).
    x: The value to convert.
    dtype: The target torch dtype. When None, `x` is returned unchanged.
    **kwargs: The remaining `_to_copy` keyword arguments (layout, device,
      pin_memory, non_blocking, memory_format). They are accepted to match
      the op schema but are not used by this lowering.

  Returns:
    `x` itself when no dtype is given, otherwise a `stablehlo.convert` of `x`
    to a ranked tensor with the same shape and the requested element type.
  """
  # Compare against None explicitly rather than relying on the truthiness of
  # a torch.dtype object.
  if dtype is None:
    return x

  result_type = ir.RankedTensorType.get(
      x.type.shape, utils.torch_dtype_to_ir_element_type(dtype)
  )
  return stablehlo.convert(result_type, x)

ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def lower_by_torch_xla2(op):
7474
lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit_no_training)
7575
lower_by_torch_xla2(torch.ops.aten._pdist_forward)
7676
lower_by_torch_xla2(torch.ops.aten._softmax)
77-
lower_by_torch_xla2(torch.ops.aten._to_copy)
7877
lower_by_torch_xla2(torch.ops.aten._unsafe_index)
7978
lower_by_torch_xla2(torch.ops.aten._unsafe_view)
8079
lower_by_torch_xla2(torch.ops.aten.acos)

ai_edge_torch/odml_torch/lowerings/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def torch_dtype_to_ir_element_type(dtype) -> ir.Type:
3737
torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
3838
torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
3939
torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
40+
torch.bfloat16: ir.BF16Type.get,
4041
}[dtype]
4142
return ty_get()
4243

0 commit comments

Comments
 (0)