Commit 9d387ec
chunniecc authored and copybara-github committed
fix rand and randn lowering
PiperOrigin-RevId: 708361782
1 parent dc45276 commit 9d387ec

File tree: 5 files changed, +154 -4 lines

ai_edge_torch/odml_torch/export.py

Lines changed: 6 additions & 1 deletion
@@ -198,7 +198,12 @@ def module_bytecode_vhlo(self) -> bytes:
     # build, which may not have the same StableHLO version as what used in
     # TFLite converter. Therefore we always serialize MLIR module in VHLO.
     # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
-    target_version = stablehlo.get_minimum_version()
+    if stablehlo.get_api_version() < 9:
+      target_version = stablehlo.get_minimum_version()
+    else:
+      target_version = stablehlo.get_version_from_compatibility_requirement(
+          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+      )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
     )
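Why the branch: older StableHLO Python bindings only expose get_minimum_version(), while bindings at API version 9 or later can derive a target version from a rolling compatibility requirement; WEEK_4 picks a version old enough that consumers up to roughly four weeks behind can still deserialize the emitted VHLO. A minimal standalone sketch of the same selection logic; the helper name is hypothetical, and the bindings module is passed in as a parameter rather than imported:

def pick_vhlo_target_version(stablehlo) -> str:
  """Chooses the StableHLO version to serialize portable VHLO at.

  `stablehlo` is assumed to be the same bindings module used in
  export.py above; only the serialization APIs shown in the diff
  are called.
  """
  if stablehlo.get_api_version() < 9:
    # Old bindings: the minimum supported version is the only option.
    return stablehlo.get_minimum_version()
  # Newer bindings: target the version from ~4 weeks ago so the
  # serialized artifact stays readable by slightly older consumers.
  return stablehlo.get_version_from_compatibility_requirement(
      stablehlo.StablehloCompatibilityRequirement.WEEK_4
  )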

ai_edge_torch/odml_torch/jax_bridge/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -12,4 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
+from ai_edge_torch.odml_torch.jax_bridge import _wrap
+from ai_edge_torch.odml_torch.jax_bridge import utils
+
+wrap = _wrap.wrap
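The re-export keeps the public entry point stable while making the utils submodule reachable without importing a private module directly. A usage sketch, assuming some JAX-traceable function jaxfn (hypothetical):

from ai_edge_torch.odml_torch import jax_bridge

# Same call as before this change; `wrap` is now a module attribute
# assigned from _wrap.wrap instead of a direct from-import.
lowering = jax_bridge.wrap(jaxfn)

# Newly importable without reaching into the private _wrap module:
helpers = jax_bridge.utils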

ai_edge_torch/odml_torch/lowerings/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
 from . import _jax_lowerings
 from . import _layer_norm
 from . import _quantized_decomposed
+from . import _rand
 from . import context
 from . import registry
 from . import utils

ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py

Lines changed: 1 addition & 2 deletions
@@ -26,6 +26,7 @@
 
 LoweringContext = context.LoweringContext
 
+
 @functools.cache
 def _log_usage(op):
   logging.warning("Use jax lowering: %s", str(op))
@@ -184,8 +185,6 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
 lower_by_torch_xla2(torch.ops.aten.pow)
 lower_by_torch_xla2(torch.ops.aten.prod)
-lower_by_torch_xla2(torch.ops.aten.rand)
-lower_by_torch_xla2(torch.ops.aten.randn)
 lower_by_torch_xla2(torch.ops.aten.reciprocal)
 lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
 lower_by_torch_xla2(torch.ops.aten.relu)
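The deleted registrations are superseded, not lost: the new _rand module added to lowerings/__init__.py above (and shown in full in the next diff) registers dedicated lowerings for aten.rand and aten.randn. A sketch of the registration pattern it uses, body elided; the full code follows below:

from ai_edge_torch.odml_torch.lowerings import registry
import torch

@registry.lower(torch.ops.aten.rand)
def _aten_rand(lctx, size, generator=None, dtype=None, **kwargs):
  ...  # emits an "odml.random_uniform" stablehlo.composite (see _rand.py)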
ai_edge_torch/odml_torch/lowerings/_rand.py (new file)

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import uuid
+
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import numpy as np
+import torch
+import torch.utils._pytree as pytree
+
+LoweringContext = context.LoweringContext
+lower = registry.lower
+
+
+def _random_lowering(
+    lctx: LoweringContext,
+    size: list[int],
+    generator,
+    dtype: torch.dtype,
+    rand_tensor,
+    composite_name: str,
+):
+  if dtype is None:
+    dtype = torch.float32
+
+  rand_tensor = rand_tensor.type(dtype)
+  data = rand_tensor.detach().numpy()
+
+  shape, _ = pytree.tree_flatten(size)
+  elty = export_utils.torch_dtype_to_ir_element_type(dtype)
+
+  decomp_name = f"{composite_name}.impl_{uuid.uuid4().hex[:8]}"
+
+  with ir.InsertionPoint(lctx.ir_module.body):
+
+    @func.FuncOp.from_py_func(
+        ir.RankedTensorType.get(
+            [len(shape)],
+            ir.IntegerType.get_signless(32),
+        ),
+        name=decomp_name,
+    )
+    def _rand_impl(_):
+      return [stablehlo.constant(ir.DenseElementsAttr.get(data))]
+
+  seed, seed2 = (
+      torch.randint(
+          torch.iinfo(torch.int64).min,
+          torch.iinfo(torch.int64).max,
+          (2,),
+          dtype=torch.int64,
+          generator=generator,
+      )
+      .detach()
+      .numpy()
+  )
+
+  shape_ = stablehlo.constant(
+      ir.DenseElementsAttr.get(np.array(shape, dtype=np.int32))
+  )
+  return stablehlo.CompositeOp(
+      result=[ir.RankedTensorType.get(shape, elty)],
+      inputs=[shape_],
+      name=composite_name,
+      composite_attributes=ir.DictAttr.get({
+          "seed": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed),
+          "seed2": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed2),
+      }),
+      decomposition=decomp_name,
+  ).results[0]
+
+
+# Schema:
+# - aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#     Device? device=None, bool? pin_memory=None) -> Tensor
+# - aten::rand.generator(SymInt[] size, *, Generator? generator,
+#     ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#     bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.rand)
+def _aten_rand(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.rand.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_uniform",
+  )
+
+
+# Schema:
+# - aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#     Device? device=None, bool? pin_memory=None) -> Tensor
+# - aten::randn.generator(SymInt[] size, *, Generator? generator,
+#     ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#     bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.randn)
+def _aten_randn(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.randn.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_standard_normal",
+  )
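Net effect of the new lowering: each aten.rand / aten.randn call becomes a stablehlo.composite op ("odml.random_uniform" / "odml.random_standard_normal") carrying seed/seed2 attributes, whose decomposition is a sampled constant tensor. A hypothetical smoke test, assuming the public ai_edge_torch.convert API; the comment describes expected, not verified, output:

import torch
import ai_edge_torch

class AddNoise(torch.nn.Module):

  def forward(self, x):
    return x + torch.rand(2, 2)

# aten.rand should now lower through _aten_rand above, producing an
# "odml.random_uniform" composite instead of going through the generic
# torch_xla2 bridge removed in _jax_lowerings.py.
edge_model = ai_edge_torch.convert(AddNoise().eval(), (torch.rand(2, 2),))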
