
Commit 299ff3a

[llm] Add IRPA transcoding. (#633)
The primary feature is the addition of `save` and `load` on the `Dataset` class. Previously, it was only possible to create one of these from a GGUF file. Now we can load from GGUF, save to IRPA, perform arbitrary transformations and updates, etc.

The rest is the machinery to manage this:

* Layout types can now be registered and have stable serialization.
* InferenceTensor types can now be registered and have stable serialization.
* A generic PlanarQuantizedTensor can represent the unpacked form of any QuantizedTensor when serializing to IRPA. By default, QuantizedTensors are saved planarized.
* The core ParameterArchive classes were completed to faithfully round-trip PyTorch tensors.
* There are still things in the CLI that refer to "gguf", but these now operate on any supported archive format (GGUF or IRPA).

Future enhancements include supporting incremental IRPA construction with a last-wins/merged approach. This would let arbitrary Theta transformations be done and appended as new segments of the file, which would probably be a really useful quantizer feature.

While not done here, this facility makes it pretty trivial to write quantizers and similar tools, since it moves them away from the bespoke, file-format-specific handling they are usually built on. Instead, they become high-level transformations on the Theta collection.
1 parent b9c9201 commit 299ff3a

24 files changed (+1222, -51 lines)
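For orientation, the workflow this enables looks roughly like the following. A minimal sketch, assuming a hypothetical import path for `Dataset`; the commit only establishes that `Dataset` gains `load` and `save`:

```python
# Sketch of GGUF -> IRPA transcoding. The import path below is an
# assumption for illustration; exact signatures may differ.
from turbine_llm.types import Dataset  # hypothetical module path

dataset = Dataset.load("model.gguf")  # previously the only construction path
# ... arbitrary Theta transformations can happen here ...
dataset.save("model.irpa")            # new: write back out as IRPA
```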

core/shark_turbine/aot/exporter.py

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,10 @@ def mlir_module(self) -> Operation:
         """Gets the MLIR module resulting from the last compilation phase."""
         return CompiledModule.get_mlir_module(self.compiled_module)
 
+    def verify(self):
+        """Runs the verifier on the module, raising an exception on failure."""
+        self.mlir_module.verify()
+
     def print_readable(self, large_elements_limit: int = 50):
         """Prints a human readable version of the current compilation IR."""
         self.mlir_module.print(large_elements_limit=large_elements_limit)
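The new `verify()` surfaces IR problems at export time rather than later in compilation. A minimal sketch, assuming the `export()` entry point re-exported from `shark_turbine.aot` (as imported in the tests below) accepts a module plus example args:

```python
# Hedged sketch: run the verifier on a freshly exported module.
import torch
import torch.nn as nn
from shark_turbine.aot import export

class Doubler(nn.Module):
    def forward(self, x):
        return x * 2.0

exported = export(Doubler(), torch.empty(4))  # signature assumed
exported.verify()  # raises on malformed IR instead of failing downstream
```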

core/shark_turbine/aot/params.py

Lines changed: 152 additions & 7 deletions
@@ -6,9 +6,11 @@
 
 from typing import Iterator, List, Optional, Set, Tuple, Union
 
-from dataclasses import dataclass
+import json
 from pathlib import Path
+import warnings
 
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -26,6 +28,7 @@
     "externalize_module_parameters",
     "save_module_parameters",
     "ParameterArchive",
+    "ParameterArchiveEntry",
     "ParameterArchiveBuilder",
 ]
 
@@ -46,6 +49,64 @@ def externalize_module_parameters(
         trait.set(tensor)
 
 
+################################################################################
+# Metadata
+################################################################################
+
+_dtype_to_name: dict[torch.dtype, str] = {
+    torch.float32: "float32",
+    torch.float64: "float64",
+    torch.complex64: "complex64",
+    torch.complex128: "complex128",
+    torch.float16: "float16",
+    torch.bfloat16: "bfloat16",
+    torch.float8_e4m3fn: "float8_e4m3fn",
+    torch.float8_e4m3fnuz: "float8_e4m3fnuz",
+    torch.float8_e5m2: "float8_e5m2",
+    torch.float8_e5m2fnuz: "float8_e5m2fnuz",
+    torch.int8: "int8",
+    torch.int16: "int16",
+    torch.int32: "int32",
+    torch.int64: "int64",
+    torch.uint16: "uint16",
+    torch.uint32: "uint32",
+    torch.uint64: "uint64",
+    torch.uint8: "uint8",
+    torch.bool: "bool",
+}
+
+_name_to_dtype: dict[str, torch.dtype] = {v: k for k, v in _dtype_to_name.items()}
+
+_metadata_prefix = "PYTORCH:"
+
+
+def _make_tensor_metadata(t: torch.Tensor) -> str:
+    """Makes a tensor metadata blob that can be used to reconstruct the tensor."""
+    dtype = t.dtype
+    try:
+        dtype_name = _dtype_to_name[dtype]
+    except KeyError:
+        dtype_name = "unknown"
+        warnings.warn(
+            f"Unknown dtype saving params: {dtype} (missing entry in params._dtype_to_name)"
+        )
+    dtype_desc = {
+        "class_name": type(dtype).__name__,
+        "is_complex": dtype.is_complex,
+        "is_floating_point": dtype.is_floating_point,
+        "is_signed": dtype.is_signed,
+        "itemsize": dtype.itemsize,
+    }
+    d = {
+        "type": "Tensor",
+        "dtype": dtype_name,
+        "shape": list(t.shape),
+        "dtype_desc": dtype_desc,
+    }
+    encoded = f"{_metadata_prefix}{json.dumps(d)}"
+    return encoded
+
+
 ################################################################################
 # Parameter archives save/load
 ################################################################################
@@ -63,6 +124,73 @@ def save_module_parameters(
     builder.save(file_path)
 
 
+class ParameterArchiveEntry:
+    """Wraps a raw ParameterIndexEntry with additional helpers."""
+
+    def __init__(self, raw: ParameterIndexEntry):
+        self.raw = raw
+
+    @property
+    def key(self) -> str:
+        return self.raw.key
+
+    def as_flat_tensor(self) -> torch.Tensor:
+        """Accesses the contents as a uint8 flat tensor.
+
+        If it is a splat, then the tensor will be a view of the splat pattern.
+
+        Raises a ValueError on unsupported entries.
+        """
+        if self.raw.is_file:
+            wrapper = np.array(self.raw.file_view, copy=False)
+        elif self.raw.is_splat:
+            wrapper = np.array(self.raw.splat_pattern, copy=True)
+        else:
+            raise ValueError(f"Unsupported ParameterIndexEntry: {self.raw}")
+
+        return torch.from_numpy(wrapper)
+
+    def as_tensor(self) -> torch.Tensor:
+        """Returns a tensor viewed with appropriate shape/dtype from metadata.
+
+        Raises a ValueError if unsupported.
+        """
+        # Decode metadata.
+        metadata = self.raw.metadata.decode()
+        if not metadata.startswith(_metadata_prefix):
+            raise ValueError(
+                f"No metadata for parameter entry {self.key}: Cannot convert to tensor"
+            )
+        metadata = metadata[len(_metadata_prefix) :]
+        d = json.loads(metadata)
+        try:
+            type_name = d["type"]
+            if d["type"] != "Tensor":
+                raise ValueError(
+                    f"Metadata for parameter entry {self.key} is not a Tensor ('{type_name}')"
+                )
+            dtype_name = d["dtype"]
+            shape = d["shape"]
+        except KeyError as e:
+            raise ValueError(f"Bad metadata for parameter entry {self.key}") from e
+
+        # Unpack/validate.
+        try:
+            dtype = _name_to_dtype[dtype_name]
+        except KeyError:
+            raise ValueError(f"Unknown dtype name '{dtype_name}'")
+        try:
+            shape = [int(d) for d in shape]
+        except ValueError as e:
+            raise ValueError(f"Illegal shape for parameter entry {self.key}") from e
+
+        t = self.as_flat_tensor()
+        return t.view(dtype=dtype).view(shape)
+
+    def __repr__(self):
+        return f"ParameterArchiveEntry({self.raw}, metadata={self.raw.metadata})"
+
+
 class ParameterArchive:
     """Allows access to a parameter archive as CPU tensors.
 
@@ -71,11 +199,16 @@ class ParameterArchive:
     """
 
     def __init__(
-        self, file_path: Optional[Union[str, Path]] = None, *, mmap: bool = True
+        self,
+        file_path: Optional[Union[str, Path]] = None,
+        *,
+        mmap: bool = True,
+        readable: bool = True,
+        writable: bool = False,
     ):
         self._index = ParameterIndex()
         if file_path is not None:
-            self.load(file_path, mmap=mmap)
+            self.load(file_path, mmap=mmap, readable=readable, writable=writable)
 
     def load(
         self,
@@ -94,8 +227,12 @@ def load(
     def index(self) -> ParameterIndex:
         return self._index
 
-    def items(self) -> List[Tuple[str, ParameterIndexEntry]]:
-        return self._index.items()
+    def items(self) -> List[Tuple[str, ParameterArchiveEntry]]:
+        """Returns the items in the archive.
+
+        Note that there can be duplicates if the archive was constructed that way.
+        """
+        return [(k, ParameterArchiveEntry(v)) for k, v in self._index.items()]
 
     def __repr__(self):
         return repr(self._index)
113250

114251
def add_tensor(self, name: str, tensor: torch.Tensor):
115252
"""Adds an named tensor to the archive."""
116-
host_array = tensor.detach().cpu().contiguous().numpy()
117-
self._index.add_buffer(name, host_array)
253+
flat_array = tensor.detach().flatten().contiguous().cpu().view(torch.uint8)
254+
host_array = flat_array.numpy()
255+
self._index.add_buffer(name, host_array, metadata=_make_tensor_metadata(tensor))
118256

119257
def add_module(self, module: nn.Module, *, prefix: str = ""):
120258
"""Adds all parameters and persistent buffers from a module hierarchy."""
121259
for name, t in _yield_saveable_tensors(module, prefix=prefix):
122260
self.add_tensor(name, t)
123261

262+
def add_blob(self, key: str, blob):
263+
"""Adds a raw blob to the index.
264+
265+
The blob must be interpretable as a buffer.
266+
"""
267+
self._index.add_buffer(key, blob)
268+
124269

125270
def _yield_saveable_tensors(
126271
module: nn.Module, *, prefix: str = ""
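Taken together, these changes make a metadata-preserving round trip possible. A minimal sketch using only APIs visible in this diff, assuming `ParameterArchive` and `ParameterArchiveBuilder` are re-exported from `shark_turbine.aot` as the `__all__` entries above suggest:

```python
import torch
from shark_turbine.aot import ParameterArchive, ParameterArchiveBuilder

builder = ParameterArchiveBuilder()
t = torch.randn(4, 8, dtype=torch.bfloat16)
builder.add_tensor("w", t)        # stores flat uint8 bytes plus "PYTORCH:" metadata
builder.save("/tmp/params.irpa")

archive = ParameterArchive("/tmp/params.irpa", mmap=False)
items = dict(archive.items())
restored = items["w"].as_tensor()  # shape/dtype recovered from the metadata blob
torch.testing.assert_close(restored, t)
```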

core/shark_turbine/dynamo/type_conversion.py

Lines changed: 20 additions & 6 deletions
@@ -9,23 +9,25 @@
 Note that there are ad-hoc type conversions spread around a bit, and we
 should consolidate them here.
 """
-from typing import List
+from typing import List, Optional
 
 import functools
 import re
 
-from iree.compiler.ir import (
+from ..support.ir_imports import (
+    tensor_d,
     Context,
     F64Type,
     IntegerType,
     RankedTensorType,
     ShapedType,
-    Type as IrType,
+    IrType,
     Location,
     Operation,
     Value,
 )
 
+
 # Match an overall torch type declaration. Groups:
 # 1. Local name (int, float, vtensor)
 # 2. Parameter block ("<...>"), including the delimitters
@@ -103,11 +105,15 @@ def convert_torch_element_type_to_native(
         return torch_type
 
     def materialize_native_to_torch(
-        self, native_value: Value, torch_type: IrType
+        self, native_value: Value, torch_type: IrType, *, static_info_cast: bool = False
     ) -> Value:
         native_type = native_value.type
         if RankedTensorType.isinstance(native_type):
             # Convert to vtensor.
+            if static_info_cast:
+                required_native_type = self.torch_type_to_native(torch_type)
+                if required_native_type != native_type:
+                    native_value = tensor_d.cast(required_native_type, native_value)
             return Operation.create(
                 "torch_c.from_builtin_tensor",
                 results=[torch_type],
@@ -138,15 +144,23 @@ def materialize_native_to_torch(
                 f"Unsupported native->torch ABI type conversion: {native_type} -> {torch_type}"
             )
 
-    def materialize_torch_to_native(self, torch_value: Value) -> Value:
+    def materialize_torch_to_native(
+        self, torch_value: Value, *, static_info_cast_to: Optional[IrType] = None
+    ) -> Value:
         native_type = self.torch_type_to_native(torch_value.type)
         if RankedTensorType.isinstance(native_type):
             # Convert to vtensor.
-            return Operation.create(
+            builtin_tensor_value = Operation.create(
                 "torch_c.to_builtin_tensor",
                 results=[native_type],
                 operands=[torch_value],
             ).result
+            # Detect type difference and assume a static cast is needed.
+            if static_info_cast_to is not None and static_info_cast_to != native_type:
+                builtin_tensor_value = tensor_d.cast(
+                    static_info_cast_to, builtin_tensor_value
+                )
+            return builtin_tensor_value
         elif IntegerType.isinstance(native_type):
             # Convert to !torch.int
             int_type = IntegerType(native_type)
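Both new keyword arguments implement the same rule: after bridging between `!torch.vtensor` and builtin `tensor`, compare the statically required type against the value's current type and insert a `tensor.cast` only when they differ. A toy illustration of that decision, with plain strings standing in for IR types (not the real MLIR API):

```python
def maybe_static_cast(current_type: str, required_type: str) -> list[str]:
    """Toy model of the insertion logic above; strings stand in for IR types."""
    ops = []
    if required_type != current_type:
        # Mirrors tensor_d.cast(required_native_type, native_value) in the diff.
        ops.append(f"tensor.cast : {current_type} -> {required_type}")
    return ops

# A dynamically shaped value flowing into a statically typed use:
print(maybe_static_cast("tensor<?x8xf32>", "tensor<97x8xf32>"))
# -> ['tensor.cast : tensor<?x8xf32> -> tensor<97x8xf32>']
```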

core/shark_turbine/transforms/general/custom_op_expansion.py

Lines changed: 8 additions & 2 deletions
@@ -32,6 +32,7 @@
 
 from ...support.ir_imports import (
     Block,
+    IrType,
     InsertionPoint,
     OpResult,
     Operation,
@@ -262,7 +263,10 @@ def __init__(
         if not desc.is_list:
             if arity == 1:
                 arg_bindings.append(
-                    type_converter.materialize_torch_to_native(operand)
+                    type_converter.materialize_torch_to_native(
+                        operand,
+                        static_info_cast_to=IrType.parse(desc.mlir_type_asm),
+                    )
                 )
             else:
                 arg_bindings.append(None)
@@ -297,7 +301,9 @@ def yield_results(self, *results: Value):
         for new_result, old_result in zip(results, torch_op_results):
             torch_type = old_result.type
             new_result = self.type_converter.materialize_native_to_torch(
-                new_result, torch_type
+                new_result,
+                torch_type,
+                static_info_cast=True,
             )
             old_result.replace_all_uses_with(new_result)
         self.yielded = True
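Here `static_info_cast_to` is derived by parsing the custom op definition's declared type string. For reference, type parsing in the upstream MLIR Python bindings (which `ir_imports` re-exports) looks like this; a standalone sketch, assuming `iree.compiler` is installed:

```python
from iree.compiler.ir import Context, Type

with Context():
    declared = Type.parse("tensor<97x8xf32>")  # e.g. from desc.mlir_type_asm
    actual = Type.parse("tensor<?x8xf32>")
    # A mismatch like this is what triggers the tensor.cast insertion above.
    print(declared != actual)  # True
```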

core/tests/aot/params_test.py

Lines changed: 8 additions & 9 deletions
@@ -12,11 +12,6 @@
 import torch
 import torch.nn as nn
 
-from iree.runtime import (
-    ParameterIndex,
-    ParameterProvider,
-)
-
 from shark_turbine.aot import (
     export,
     externalize_module_parameters,
@@ -50,8 +45,10 @@ def testCreateArchive(self):
             # lock the file for an arbitrary duration.
             archive = ParameterArchive(file_path, mmap=False)
             items = dict(archive.items())
-            self.assertIn("classifier.weight", items)
-            self.assertIn("classifier.bias", items)
+            weight = items["classifier.weight"].as_tensor()
+            bias = items["classifier.bias"].as_tensor()
+            torch.testing.assert_close(weight, m.classifier.weight)
+            torch.testing.assert_close(bias, m.classifier.bias)
         finally:
             file_path.unlink()
 
@@ -65,8 +62,10 @@ def testCreateArchiveWithPrefixScope(self):
             # lock the file for an arbitrary duration.
             archive = ParameterArchive(file_path, mmap=False)
             items = dict(archive.items())
-            self.assertIn("foobar.model.classifier.weight", items)
-            self.assertIn("foobar.model.classifier.bias", items)
+            weight = items["foobar.model.classifier.weight"].as_tensor()
+            bias = items["foobar.model.classifier.bias"].as_tensor()
+            torch.testing.assert_close(weight, m.classifier.weight)
+            torch.testing.assert_close(bias, m.classifier.bias)
         finally:
             file_path.unlink()
 
core/tests/transforms/general/custom_op_expansion_test.py

Lines changed: 4 additions & 2 deletions
@@ -34,16 +34,18 @@ def setUpClass(cls):
     def testTensorArgReturn(self):
         m = self.run_test_case("custom_op_simple.mlir")
         m_asm = str(m)
+        print(m_asm)
         self.assertNotIn("torch.operator", m_asm)
         self.assertIn(
             "%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[97,8],f32> -> tensor<97x8xf32>",
             m_asm,
         )
+        # TODO: Upgrade to a FileCheck style test so we can pattern match that
+        # the casts are inserted properly.
         self.assertIn(
-            "%1 = torch_c.from_builtin_tensor %0 : tensor<97x8xf32> -> !torch.vtensor<[97,8],f32>",
+            "%1 = torch_c.from_builtin_tensor %cast_0 : tensor<97x8xf32> -> !torch.vtensor<[97,8],f32>",
            m_asm,
         )
-        print(m_asm)
 
     def testStringAttrArg(self):
         global _TEST_STRING_ATTR
