# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
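"""Lowerings for torch random number generator ops (aten.rand, aten.randn) to StableHLO composite ops."""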
import uuid

from ai_edge_torch.odml_torch import export_utils
from ai_edge_torch.odml_torch.lowerings import context
from ai_edge_torch.odml_torch.lowerings import registry
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo as stablehlo
import numpy as np
import torch
import torch.utils._pytree as pytree

LoweringContext = context.LoweringContext
lower = registry.lower

def _random_lowering(
    lctx: LoweringContext,
    size: list[int],
    generator,
    dtype: torch.dtype,
    rand_tensor,
    composite_name: str,
):
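  """Lowers a torch random op to a `stablehlo.composite` op.

  The random values in `rand_tensor` are sampled eagerly at lowering time and
  emitted as a constant inside the composite's decomposition function, so the
  exported graph reproduces them by default. Seeds drawn from `generator` are
  attached as composite attributes so consumers that recognize
  `composite_name` can re-sample at runtime instead.
  """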
  if dtype is None:
    dtype = torch.float32

  rand_tensor = rand_tensor.type(dtype)
  data = rand_tensor.detach().numpy()

  shape, _ = pytree.tree_flatten(size)
  elty = export_utils.torch_dtype_to_ir_element_type(dtype)

  decomp_name = f"{composite_name}.impl_{uuid.uuid4().hex[:8]}"

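  # Build the decomposition function in the module body; it ignores the shape
  # operand and simply returns the pre-sampled values as a constant.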
  with ir.InsertionPoint(lctx.ir_module.body):

    @func.FuncOp.from_py_func(
        ir.RankedTensorType.get(
            [len(shape)],
            ir.IntegerType.get_signless(32),
        ),
        name=decomp_name,
    )
    def _rand_impl(_):
      return [stablehlo.constant(ir.DenseElementsAttr.get(data))]

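  # Draw two int64 seeds from the torch generator; they are recorded as the
  # `seed`/`seed2` attributes on the composite op below.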
  seed, seed2 = (
      torch.randint(
          torch.iinfo(torch.int64).min,
          torch.iinfo(torch.int64).max,
          (2,),
          dtype=torch.int64,
          generator=generator,
      )
      .detach()
      .numpy()
  )

  shape_ = stablehlo.constant(
      ir.DenseElementsAttr.get(np.array(shape, dtype=np.int32))
  )
  return stablehlo.CompositeOp(
      result=[ir.RankedTensorType.get(shape, elty)],
      inputs=[shape_],
      name=composite_name,
      composite_attributes=ir.DictAttr.get({
          "seed": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed),
          "seed2": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed2),
      }),
      decomposition=decomp_name,
  ).results[0]


# Schema:
# - aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
#     Device? device=None, bool? pin_memory=None) -> Tensor
# - aten::rand.generator(SymInt[] size, *, Generator? generator,
#     ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#     bool? pin_memory=None) -> Tensor
@registry.lower(torch.ops.aten.rand)
def _aten_rand(
    lctx: LoweringContext,
    size,
    generator=None,
    dtype=None,
    layout=torch.strided,
    device=None,
    pin_memory=False,
):
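  """Lowers aten.rand to an `odml.random_uniform` composite op."""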
  return _random_lowering(
      lctx,
      size,
      generator,
      dtype,
      rand_tensor=torch.ops.aten.rand.generator(
          size, generator=generator, dtype=dtype
      ),
      composite_name="odml.random_uniform",
  )


# Schema:
# - aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
#     Device? device=None, bool? pin_memory=None) -> Tensor
# - aten::randn.generator(SymInt[] size, *, Generator? generator,
#     ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#     bool? pin_memory=None) -> Tensor
@registry.lower(torch.ops.aten.randn)
def _aten_randn(
    lctx: LoweringContext,
    size,
    generator=None,
    dtype=None,
    layout=torch.strided,
    device=None,
    pin_memory=False,
):
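  """Lowers aten.randn to an `odml.random_standard_normal` composite op."""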
  return _random_lowering(
      lctx,
      size,
      generator,
      dtype,
      rand_tensor=torch.ops.aten.randn.generator(
          size, generator=generator, dtype=dtype
      ),
      composite_name="odml.random_standard_normal",
  )