Skip to content

Commit f525407

Browse files
committed
Added interface methods which expose a few basic packing and alignment facts
1 parent bb344d5 commit f525407

File tree

2 files changed

+66
-5
lines changed

2 files changed

+66
-5
lines changed

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,18 @@ def Builtin_Float8E5M2 : Builtin_FloatType<"Float8E5M2", "f8E5M2",
132132

133133
/// Get the storage type as a string.
134134
std::string getStorageType() const { return "f8E5M2"; }
135+
136+
/// Check if this 8-bit floating point type uses packed representation.
137+
bool isPacked() const { return false; }
138+
139+
/// Get the logical bit width per value for this 8-bit floating point type.
140+
unsigned getLogicalBitWidth() const { return 8; }
141+
142+
/// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
143+
unsigned getElementsPerByte() const { return 1; }
144+
145+
/// Get the preferred alignment in bytes for this 8-bit floating point type.
146+
std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
135147
}];
136148
}
137149

@@ -191,6 +203,18 @@ def Builtin_Float8E4M3FN : Builtin_FloatType<"Float8E4M3FN", "f8E4M3FN",
191203

192204
/// Get the storage type as a string.
193205
std::string getStorageType() const { return "f8E4M3FN"; }
206+
207+
/// Check if this 8-bit floating point type uses packed representation.
208+
bool isPacked() const { return false; }
209+
210+
/// Get the logical bit width per value for this 8-bit floating point type.
211+
unsigned getLogicalBitWidth() const { return 8; }
212+
213+
/// Get the number of logical elements that fit in one byte for this 8-bit floating point type.
214+
unsigned getElementsPerByte() const { return 1; }
215+
216+
/// Get the preferred alignment in bytes for this 8-bit floating point type.
217+
std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
194218
}];
195219
}
196220

@@ -617,6 +641,18 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer",
617641
std::string getStorageType() const {
618642
return (isStorageSigned() ? "i" : "u") + std::to_string(getWidth());
619643
}
644+
645+
/// Check if this integer type uses packed representation.
646+
bool isPacked() const { return false; }
647+
648+
/// Get the logical bit width per value for this integer type.
649+
unsigned getLogicalBitWidth() const { return getWidth(); }
650+
651+
/// Get the number of logical elements that fit in one byte for this integer type.
652+
unsigned getElementsPerByte() const { return 1; }
653+
654+
/// Get the preferred alignment in bytes for this integer type.
655+
std::optional<unsigned> getPreferredAlignmentBytes() const { return std::nullopt; }
620656
}];
621657
}
622658

mlir/include/mlir/IR/QuantStorageTypeInterface.td

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ include "mlir/IR/OpBase.td"
66
def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
77
let description = [{
88
Interface for types that can be used as storage types in Quant dialect.
9-
This interface provides methods to determine storage characteristics for quantization purposes.
9+
This interface provides methods to determine storage characteristics for quantization purposes,
10+
including packing behavior, and alignment requirements.
1011
}];
1112
let cppNamespace = "::mlir";
1213

@@ -18,25 +19,49 @@ def QuantStorageTypeInterface : TypeInterface<"QuantStorageTypeInterface"> {
1819
"bool", "isStorageSigned", (ins)>,
1920

2021
InterfaceMethod<[{
21-
Get the bit width of this integer type.
22+
Get the bit width of this type.
2223
Returns the number of bits used to store values of this type.
2324
}],
2425
"unsigned", "getStorageWidth", (ins)>,
2526

2627
InterfaceMethod<[{
27-
Get default minimum value for this integer type.
28+
Get default minimum value for this type.
2829
}],
2930
"int64_t", "getDefaultMinimum", (ins)>,
3031

3132
InterfaceMethod<[{
32-
Get default maximum value for this integer type.
33+
Get default maximum value for this type.
3334
}],
3435
"int64_t", "getDefaultMaximum", (ins)>,
3536

3637
InterfaceMethod<[{
3738
Get the storage type as a string.
3839
}],
39-
"std::string", "getStorageType", (ins)>
40+
"std::string", "getStorageType", (ins)>,
41+
42+
InterfaceMethod<[{
43+
Check if the storage type uses packed representation.
44+
Returns true if multiple values are packed into one byte (e.g., sub-byte types),
45+
false if value uses full byte.
46+
}],
47+
"bool", "isPacked", (ins)>,
48+
49+
InterfaceMethod<[{
50+
Get the logical bit width per value.
51+
For packed sub-byte types, this may differ from getStorageWidth().
52+
}],
53+
"unsigned", "getLogicalBitWidth", (ins)>,
54+
55+
InterfaceMethod<[{
56+
Get the number of logical elements that fit in one byte.
57+
For packed sub-byte types, this returns how many values can be stored per byte.
58+
}],
59+
"unsigned", "getElementsPerByte", (ins)>,
60+
61+
InterfaceMethod<[{
62+
Returns the preferred alignment for this type, in bytes.
63+
}],
64+
"std::optional<unsigned>", "getPreferredAlignmentBytes", (ins)>
4065
];
4166

4267
}

0 commit comments

Comments
 (0)