 from ai_edge_litert.mlir.dialects import arith
 import numpy as np
 import torch
-from ai_edge_litert.mlir._mlir_libs import converter_api_ext

 config = _config.config

 class InlineConstsContext(lowerings.context.LoweringContextPlugin):
   """The context object for inlining constants."""

-  enable_lazy_constants: bool = False
+  enable_resource_constants: bool = False
   constant_cache: dict[int, ir.Attribute] = dataclasses.field(
       default_factory=dict
   )
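For orientation, `constant_cache` keys materialized attributes by a content fingerprint, so identical weights are serialized once per module. A minimal, self-contained sketch of that get-or-build pattern (the names below are hypothetical stand-ins; the real plugin stores `ir.Attribute` values):

```python
import dataclasses


@dataclasses.dataclass
class FingerprintCache:
  """Hypothetical stand-in for InlineConstsContext.constant_cache."""

  entries: dict[int, str] = dataclasses.field(default_factory=dict)

  def get_or_build(self, key: int, build):
    # Serialize only on the first sighting of this fingerprint.
    if key not in self.entries:
      self.entries[key] = build()
    return self.entries[key]


cache = FingerprintCache()
first = cache.get_or_build(0x2AF3, lambda: 'attr')
second = cache.get_or_build(0x2AF3, lambda: 'attr')  # cache hit; build skipped
assert first is second
```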
@@ -63,37 +62,12 @@ def _tensor_fingerprint(tensor: torch.Tensor) -> int:
 def _tensor_to_mlir_compatible_array(tensor: torch.Tensor) -> np.ndarray:
   """Converts a tensor to a numpy array that is compatible with MLIR contiguity and endianness."""
   if hasattr(tensor, 'detach'):
-    arr = tensor.detach().cpu().numpy()
+    arr = tensor.contiguous().detach().cpu().numpy()
   else:
     arr = np.array(tensor)

-  if arr.dtype == bool or arr.dtype == np.bool_:
-    # packbits returns uint8; bitorder='little' is crucial for MLIR
-    packed = np.packbits(arr, axis=None, bitorder='little')
-    return packed
-
-  target_dtype = {
-      # Floating point
-      np.float16: '<f2',
-      np.float32: '<f4',
-      np.float64: '<f8',
-      # Signed Integers
-      np.int8: '<i1',
-      np.int16: '<i2',
-      np.int32: '<i4',
-      np.int64: '<i8',
-      # Unsigned Integers
-      np.uint8: '<u1',
-      np.uint16: '<u2',
-      np.uint32: '<u4',
-      np.uint64: '<u8',
-  }.get(arr.dtype.type)
-
-  if target_dtype is None:
-    raise TypeError(f'Unsupported dtype for MLIR conversion: {arr.dtype}')
-
-  # Ensure C-contiguity and the specific bit-width/endianness
-  return np.ascontiguousarray(arr, dtype=target_dtype)
+  # Ensure C-contiguity
+  return np.ascontiguousarray(arr)


 def _get_tensor_uniform_value(tensor: torch.Tensor):
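The simplified helper drops the explicit `'<f4'`-style casts and the bool bit-packing, so it now assumes the input is already in native little-endian layout; `np.ascontiguousarray` only normalizes memory order. A small plain-numpy sketch of what that call does and does not do:

```python
import numpy as np

# A strided view is not C-contiguous; ascontiguousarray copies it into a
# dense C-order buffer without touching dtype or byte order.
view = np.arange(12, dtype=np.float32).reshape(3, 4)[:, ::2]
assert not view.flags['C_CONTIGUOUS']

dense = np.ascontiguousarray(view)
assert dense.flags['C_CONTIGUOUS']

# Byte order stays whatever it was ('=' native, '|' not applicable), so the
# little-endian guarantee the removed '<f2'/'<f4' mapping enforced is now an
# implicit assumption about the host.
assert dense.dtype.byteorder in ('=', '<', '|')
```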
@@ -134,7 +108,7 @@ def _clamp_inf_values(tensor: torch.Tensor):
   """Clamps a tensor to the min/max value for float tensors."""
   if torch.is_floating_point(tensor):
     info = torch.finfo(tensor.dtype)
-    tensor = torch.clamp(tensor, info.min, info.max)
+    tensor.clamp_(min=info.min, max=info.max)
   return tensor


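Switching from `torch.clamp` to `Tensor.clamp_` clamps in place rather than allocating a full clamped copy, which matters for large weights; the tradeoff is that in-place ops mutate the tensor the caller passed in. A quick illustration of the semantics:

```python
import torch

t = torch.tensor([float('inf'), float('-inf'), 1.0])
info = torch.finfo(t.dtype)

t.clamp_(min=info.min, max=info.max)  # mutates t and returns it; no copy
print(t)  # tensor([ 3.4028e+38, -3.4028e+38,  1.0000e+00])
```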
@@ -161,6 +135,7 @@ def tensor_lowering_placeholder_lowering(
 ):
   """Lower the placeholder function to a constant op."""
   const_ctx = InlineConstsContext.get(lctx)
+  x = x.contiguous().detach().cpu()

   x_fingerprint = _tensor_fingerprint(x)
   elty = lowering_utils.torch_dtype_to_ir_element_type(x.dtype)
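Normalizing `x` to a contiguous CPU tensor up front makes its byte view well-defined regardless of device or strides before fingerprinting and serialization. `_tensor_fingerprint` itself is defined earlier in this file; purely as a hypothetical sketch (the real implementation may differ), a content hash over those normalized bytes could look like:

```python
import hashlib

import torch


def content_fingerprint(t: torch.Tensor) -> int:
  """Hypothetical content hash; not the file's _tensor_fingerprint."""
  t = t.contiguous().detach().cpu()
  h = hashlib.sha256()
  h.update(repr((str(t.dtype), tuple(t.shape))).encode())  # metadata
  h.update(t.numpy().tobytes())                            # raw payload
  return int.from_bytes(h.digest()[:8], 'little')


assert content_fingerprint(torch.ones(3)) == content_fingerprint(torch.ones(3))
```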
@@ -171,52 +146,33 @@ def tensor_lowering_placeholder_lowering(
   if cached_attr is not None:
     return _build_const(cached_attr, tensor_type)

-  use_lazy_attr = const_ctx.enable_lazy_constants
-  if x.dtype not in [torch.float32]:
-    use_lazy_attr = False
+  use_resource_attr = const_ctx.enable_resource_constants
+  if x.dtype not in [torch.float32, torch.int32]:
+    use_resource_attr = False

   # If the tensor is too small, just use a dense elements attr.
-  if x.numel() * x.element_size() < config.lazy_constant_numel_threshold:
-    use_lazy_attr = False
+  if x.numel() * x.element_size() < config.resource_constant_numel_threshold:
+    use_resource_attr = False

-  # If not using lazy attr, clamp inf values to the min/max value of the
-  # tensor's dtype. Otherwise, rely on the bytes getter to clamp values
-  # lazily.
-  if not use_lazy_attr:
-    x = _clamp_inf_values(x)
+  x = _clamp_inf_values(x)

   # If the tensor is uniform, use a splat constant.
   uniform_value = _get_tensor_uniform_value(x)
   if uniform_value is not None:
-    use_lazy_attr = False
+    use_resource_attr = False

   if uniform_value is not None:
     attr = lowering_utils.splat_attr(
         uniform_value,
         tensor_type.element_type,
         tensor_type.shape,
     )
-  elif use_lazy_attr:
-
-    def chunk_iterator_factory():
-      nonlocal x
-      element_size = x.element_size()
-      elements_per_chunk = (
-          config.lazy_constant_getter_chunk_size // element_size
-      )
-
-      # x.view(-1) is a metadata-only operation (0 bytes allocated)
-      flat_x = x.view(-1)
-      numel = flat_x.numel()
-
-      for i in range(0, numel, elements_per_chunk):
-        chunk = flat_x[i : i + elements_per_chunk]
-        chunk = _clamp_inf_values(chunk)
-        chunk_data = _tensor_to_mlir_compatible_array(chunk).tobytes()
-        yield chunk_data
-
-    attr = converter_api_ext.get_py_chunked_callback_resource_attr(
-        tensor_type, chunk_iterator_factory
+  elif use_resource_attr:
+    arr = _tensor_to_mlir_compatible_array(x)
+    attr = ir.DenseResourceElementsAttr.get_from_buffer(
+        memoryview(arr),
+        f'TENSOR_{x_fingerprint}',
+        tensor_type,
     )
   else:
     arr = _tensor_to_mlir_compatible_array(x)
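Unlike `DenseElementsAttr`, which prints its payload inline, `DenseResourceElementsAttr` stores the bytes in the module's resource section and leaves only a named handle (here `TENSOR_<fingerprint>`) in the IR text. A minimal standalone sketch against the upstream `mlir` Python bindings, which the vendored `ai_edge_litert.mlir.ir` module appears to mirror (an assumption, not confirmed by the diff):

```python
import numpy as np
from mlir import ir  # upstream bindings; the diff uses the vendored copy

with ir.Context(), ir.Location.unknown():
  ty = ir.RankedTensorType.get([2, 2], ir.F32Type.get())
  arr = np.ascontiguousarray(np.ones((2, 2), dtype=np.float32))

  # Registers the buffer as a named resource; the printed IR then reads
  # dense_resource<TENSOR_example> instead of inlining the 16 bytes.
  attr = ir.DenseResourceElementsAttr.get_from_buffer(
      memoryview(arr), 'TENSOR_example', ty
  )
  print(attr)  # dense_resource<TENSOR_example> : tensor<2x2xf32>
```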
@@ -227,7 +183,7 @@ def chunk_iterator_factory():


 def inline_consts(exported_program: torch.export.ExportedProgram) -> None:
-  """Inlines exported program's constant inputs by replacing with lazy_tensor_placeholder."""
+  """Inlines exported program's constant inputs by replacing with resource_tensor_placeholder."""
   flat_user_inputs, _ = exported_program._get_flat_args_with_check(
       *exported_program.example_inputs
   )