Skip to content

Commit 6d164cc

Browse files
chunniencc authored and copybara-github committed
Optimize Flatbuffer export for numeric DenseElementsAttr and DenseResourceElementsAttr.
PiperOrigin-RevId: 871622371
1 parent 0805af3 commit 6d164cc

4 files changed

Lines changed: 29 additions & 81 deletions

File tree

litert_torch/_config.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,16 @@ def layout_optimize_partitioner(self, value: str) -> None:
7474
os.environ["LAYOUT_OPTIMIZE_PARTITIONER"] = str(value).upper()
7575

7676
@property
77-
def lazy_constant_numel_threshold(self) -> int:
77+
def resource_constant_numel_threshold(self) -> int:
7878
"""The threshold for the number of elements in a constant to be eligible to be lazily loaded during lightweight conversion."""
7979
default = 1024 * 1024 # 1MB
80-
return _get_int_env_var("LAZY_CONSTANT_NUMEL_THRESHOLD", default=default)
80+
return _get_int_env_var(
81+
"RESOURCE_CONSTANT_NUMEL_THRESHOLD", default=default
82+
)
8183

82-
@lazy_constant_numel_threshold.setter
83-
def lazy_constant_numel_threshold(self, value: int) -> None:
84-
os.environ["LAZY_CONSTANT_NUMEL_THRESHOLD"] = str(value)
85-
86-
@property
87-
def lazy_constant_getter_chunk_size(self) -> int:
88-
"""The chunk size for the lazy constant getter during lightweight conversion."""
89-
default = 32 * 1024 * 1024 # 32MB
90-
return _get_int_env_var("LAZY_CONSTANT_GETTER_CHUNK_SIZE", default=default)
91-
92-
@lazy_constant_getter_chunk_size.setter
93-
def lazy_constant_getter_chunk_size(self, value: int) -> None:
94-
os.environ["LAZY_CONSTANT_GETTER_CHUNK_SIZE"] = str(value)
84+
@resource_constant_numel_threshold.setter
85+
def resource_constant_numel_threshold(self, value: int) -> None:
86+
os.environ["RESOURCE_CONSTANT_NUMEL_THRESHOLD"] = str(value)
9587

9688
@property
9789
def show_progress(self) -> bool:

litert_torch/_convert/litert_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def exported_programs_to_flatbuffer(
120120

121121
ir_context = backend.export_utils.create_ir_context()
122122
cross_program_inline_consts_ctx = inline_consts_lib.InlineConstsContext(
123-
enable_lazy_constants=lightweight_conversion,
123+
enable_resource_constants=lightweight_conversion,
124124
)
125125

126126
lowered_programs = []

litert_torch/backend/inline_consts.py

Lines changed: 20 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from ai_edge_litert.mlir.dialects import arith
2525
import numpy as np
2626
import torch
27-
from ai_edge_litert.mlir._mlir_libs import converter_api_ext
2827

2928
config = _config.config
3029

@@ -33,7 +32,7 @@
3332
class InlineConstsContext(lowerings.context.LoweringContextPlugin):
3433
"""The context object for inlining constants."""
3534

36-
enable_lazy_constants: bool = False
35+
enable_resource_constants: bool = False
3736
constant_cache: dict[int, ir.Attribute] = dataclasses.field(
3837
default_factory=dict
3938
)
@@ -63,37 +62,12 @@ def _tensor_fingerprint(tensor: torch.Tensor) -> int:
6362
def _tensor_to_mlir_compatible_array(tensor: torch.Tensor) -> np.ndarray:
6463
"""Converts a tensor to a numpy array that is compatible with MLIR contiguity and endianness."""
6564
if hasattr(tensor, 'detach'):
66-
arr = tensor.detach().cpu().numpy()
65+
arr = tensor.contiguous().detach().cpu().numpy()
6766
else:
6867
arr = np.array(tensor)
6968

70-
if arr.dtype == bool or arr.dtype == np.bool_:
71-
# packbits returns uint8; bitorder='little' is crucial for MLIR
72-
packed = np.packbits(arr, axis=None, bitorder='little')
73-
return packed
74-
75-
target_dtype = {
76-
# Floating point
77-
np.float16: '<f2',
78-
np.float32: '<f4',
79-
np.float64: '<f8',
80-
# Signed Integers
81-
np.int8: '<i1',
82-
np.int16: '<i2',
83-
np.int32: '<i4',
84-
np.int64: '<i8',
85-
# Unsigned Integers
86-
np.uint8: '<u1',
87-
np.uint16: '<u2',
88-
np.uint32: '<u4',
89-
np.uint64: '<u8',
90-
}.get(arr.dtype.type)
91-
92-
if target_dtype is None:
93-
raise TypeError(f'Unsupported dtype for MLIR conversion: {arr.dtype}')
94-
95-
# Ensure C-contiguity and the specific bit-width/endianness
96-
return np.ascontiguousarray(arr, dtype=target_dtype)
69+
# Ensure C-contiguity
70+
return np.ascontiguousarray(arr)
9771

9872

9973
def _get_tensor_uniform_value(tensor: torch.Tensor):
@@ -134,7 +108,7 @@ def _clamp_inf_values(tensor: torch.Tensor):
134108
"""Clamps a tensor to the min/max value for float tensors."""
135109
if torch.is_floating_point(tensor):
136110
info = torch.finfo(tensor.dtype)
137-
tensor = torch.clamp(tensor, info.min, info.max)
111+
tensor.clamp_(min=info.min, max=info.max)
138112
return tensor
139113

140114

@@ -161,6 +135,7 @@ def tensor_lowering_placeholder_lowering(
161135
):
162136
"""Lower the placeholder function to a constant op."""
163137
const_ctx = InlineConstsContext.get(lctx)
138+
x = x.contiguous().detach().cpu()
164139

165140
x_fingerprint = _tensor_fingerprint(x)
166141
elty = lowering_utils.torch_dtype_to_ir_element_type(x.dtype)
@@ -171,52 +146,33 @@ def tensor_lowering_placeholder_lowering(
171146
if cached_attr is not None:
172147
return _build_const(cached_attr, tensor_type)
173148

174-
use_lazy_attr = const_ctx.enable_lazy_constants
175-
if x.dtype not in [torch.float32]:
176-
use_lazy_attr = False
149+
use_resource_attr = const_ctx.enable_resource_constants
150+
if x.dtype not in [torch.float32, torch.int32]:
151+
use_resource_attr = False
177152

178153
# If the tensor is too small, just use a dense elements attr.
179-
if x.numel() * x.element_size() < config.lazy_constant_numel_threshold:
180-
use_lazy_attr = False
154+
if x.numel() * x.element_size() < config.resource_constant_numel_threshold:
155+
use_resource_attr = False
181156

182-
# If not using lazy attr, clamp inf values to the min/max value of the
183-
# tensor's dtype. Otherwise, rely on the bytes getter to clamp values
184-
# lazily.
185-
if not use_lazy_attr:
186-
x = _clamp_inf_values(x)
157+
x = _clamp_inf_values(x)
187158

188159
# If the tensor is uniform, use a splat constant.
189160
uniform_value = _get_tensor_uniform_value(x)
190161
if uniform_value is not None:
191-
use_lazy_attr = False
162+
use_resource_attr = False
192163

193164
if uniform_value is not None:
194165
attr = lowering_utils.splat_attr(
195166
uniform_value,
196167
tensor_type.element_type,
197168
tensor_type.shape,
198169
)
199-
elif use_lazy_attr:
200-
201-
def chunk_iterator_factory():
202-
nonlocal x
203-
element_size = x.element_size()
204-
elements_per_chunk = (
205-
config.lazy_constant_getter_chunk_size // element_size
206-
)
207-
208-
# x.view(-1) is a metadata-only operation (0 bytes allocated)
209-
flat_x = x.view(-1)
210-
numel = flat_x.numel()
211-
212-
for i in range(0, numel, elements_per_chunk):
213-
chunk = flat_x[i : i + elements_per_chunk]
214-
chunk = _clamp_inf_values(chunk)
215-
chunk_data = _tensor_to_mlir_compatible_array(chunk).tobytes()
216-
yield chunk_data
217-
218-
attr = converter_api_ext.get_py_chunked_callback_resource_attr(
219-
tensor_type, chunk_iterator_factory
170+
elif use_resource_attr:
171+
arr = _tensor_to_mlir_compatible_array(x)
172+
attr = ir.DenseResourceElementsAttr.get_from_buffer(
173+
memoryview(arr),
174+
f'TENSOR_{x_fingerprint}',
175+
tensor_type,
220176
)
221177
else:
222178
arr = _tensor_to_mlir_compatible_array(x)
@@ -227,7 +183,7 @@ def chunk_iterator_factory():
227183

228184

229185
def inline_consts(exported_program: torch.export.ExportedProgram) -> None:
230-
"""Inlines exported program's constant inputs by replacing with lazy_tensor_placeholder."""
186+
"""Inlines exported program's constant inputs by replacing with resource_tensor_placeholder."""
231187
flat_user_inputs, _ = exported_program._get_flat_args_with_check(
232188
*exported_program.example_inputs
233189
)

litert_torch/generative/export_hf/core/export_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def export_text_prefill_decode_model(
261261
start_time = time.perf_counter()
262262

263263
print('Converting model...')
264-
lrt_model = converter.convert(strict_export=False)
264+
lrt_model = converter.convert(lightweight_conversion=True, strict_export=False)
265265
print('Converting model done.')
266266

267267
lrt_model = mu_pass_lib.update_model(lrt_model)

0 commit comments

Comments (0)