Commit 54ed073

1 parent f8c4137 commit 54ed073

22 files changed, +3718 -168 lines changed

WORKSPACE.bazel

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,9 +17,9 @@ workspace(name = "stablehlo")
1717

1818
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
1919

20-
LLVM_COMMIT = "6d4a0935c850ec3ddfc70c4ba97b98adc35c676e"
20+
LLVM_COMMIT = "fc44a4fcd3c54be927c15ddd9211aca1501633e7"
2121

22-
LLVM_SHA256 = "01d73796d7c614c809fe464b4beb1cfb5e275e805e7b985ef90c97fe22f01154"
22+
LLVM_SHA256 = "d228aebe5583c69c4e48fd7a8e149e3d22ee6dafaeae94009467143d32d9bfc4"
2323

2424
http_archive(
2525
name = "llvm-raw",

build_tools/github_actions/lint_version.sh

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ VERSIONED_COMPAT_TEST="$COMPAT_TEST_BASE.$TEST_VERSION.mlir"
4343
VERSIONED_COMPAT_TEST_BC="$COMPAT_TEST_BASE.$TEST_VERSION.mlir.bc"
4444

4545
show_help() {
46-
HELP_URL="https://github.com/openxla/stablehlo/blob/main/docs/vhlo.md#add-versioned-serialization-test"
46+
HELP_URL="https://github.com/openxla/stablehlo/blob/main/docs/vhlo_checklist.md#4-add-versioned-serialization-test"
4747
echo "For details on creating versioned tests for a new minor version of"
4848
echo "StableHLO, see the instructions on:"
4949
echo "$HELP_URL"

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
1-
6d4a0935c850ec3ddfc70c4ba97b98adc35c676e
1+
fc44a4fcd3c54be927c15ddd9211aca1501633e7

docs/spec.md

Lines changed: 90 additions & 23 deletions
@@ -89,7 +89,7 @@ completely numeric to simplify generation of StableHLO programs.
8989

9090
```ebnf
9191
Type ::= ValueType | NonValueType
92-
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
92+
ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
9393
NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
9494
```
9595

@@ -229,6 +229,21 @@ TupleType ::= 'tuple' '<' TupleElementTypes '>'
229229
TupleElementTypes ::= [ValueType {',' ValueType}]
230230
```
231231

232+
**Buffer types** represent buffers. For example, in XLA, buffers are
233+
multidimensional arrays with consistent storage. Similar to **tensor types**,
234+
buffer types have a **shape** and an **element type**, where a shape represents
235+
non-negative or unknown **dimension sizes** in the ascending order of the
236+
corresponding **dimensions** (which are also called **axes**) numbered from `0`
237+
to `R-1`. The number of dimensions `R` is called **rank**. For example,
238+
`memref<2x3xf32>` is a buffer type with shape `2x3` and element type `f32`. It
239+
has two dimensions (or, in other words, two axes) - 0th dimension and 1st
240+
dimension - whose sizes are 2 and 3. Its rank is 2.
241+
242+
Buffers can be allocated using a `custom_call` to `CreateBuffer` or `Pin` and
243+
deallocated via a `custom_call` to `Unpin`. Only `custom_call` ops can read and
244+
write the content inside buffers. See [custom_call](#custom_call) for more
245+
detail.
246+
232247
**Tuple types** represent tuples, i.e. heterogeneous lists. Tuples are a legacy
233248
feature which only exists for compatibility with HLO. In HLO, tuples are
234249
used to represent variadic inputs and outputs. In StableHLO, variadic inputs and
@@ -2433,21 +2448,63 @@ the XLA compiler. In the future, we are planning to unify this metadata
24332448

24342449
#### Inputs
24352450

2436-
| Label | Name | Type |
2437-
|-------|-----------------------|---------------------------------------------------|
2438-
| (I1) | `inputs` | variadic number of values |
2439-
| (I2) | `call_target_name` | constant of type `string` |
2440-
| (I3) | `has_side_effect` | constant of type `i1` |
2441-
| (I4) | `backend_config` | constant of type `string` or attribute dictionary |
2442-
| (I5) | `api_version` | constant of type `si32` |
2443-
| (I6) | `called_computations` | variadic number of constants of type `string` |
2451+
| Label | Name | Type |
2452+
|-------|--------------------------|------------------------------------------------------------|
2453+
| (I1) | `inputs` | variadic number of values |
2454+
| (I2) | `call_target_name` | constant of type `string` |
2455+
| (I3) | `has_side_effect` | constant of type `i1` |
2456+
| (I4) | `backend_config` | constant of type `string` or attribute dictionary |
2457+
| (I5) | `api_version` | constant of type `si32` |
2458+
| (I6) | `called_computations` | variadic number of constants of type `string` |
2459+
| (I7) | `output_operand_aliases` | specify the aliasing parts in the outputs and operands |
24442460

24452461
#### Outputs
24462462

24472463
| Name | Type |
24482464
|-----------|---------------------------|
24492465
| `results` | variadic number of values |
24502466

2467+
### (XLA GPU Support) Special custom_call targets
2468+
2469+
There are three special `call_target_name` values related to `buffer` types:
2470+
`CreateBuffer` creates an uninitialized `buffer`, `Pin` creates an initialized
2471+
`buffer` and `Unpin` deallocates a `buffer` and returns the content of the
2472+
`buffer`.
2473+
2474+
```mlir
2475+
%uninitialized_buffer = "stablehlo.custom_call"() {
2476+
call_target_name = "CreateBuffer",
2477+
api_version = 4 : i32
2478+
} : () -> memref<4xf64>
2479+
2480+
%initialized_buffer = "stablehlo.custom_call"(%init_value) {
2481+
call_target_name = "Pin",
2482+
api_version = 4 : i32
2483+
} : (tensor<4xf64>) -> memref<4xf64>
2484+
2485+
%dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
2486+
call_target_name = "Unpin",
2487+
api_version = 4 : i32
2488+
} : (memref<4xf64>) -> tensor<4xf64>
2489+
2490+
```
2491+
2492+
### Alias
2493+
2494+
Some custom_call ops may require a part of the outputs and a part of the
2495+
operands to share the same memory. This can be expressed via
2496+
`output_operand_aliases`. An alias pair representation consists of a list of output
2497+
tuple indices representing the output part, and an operand_index along with a
2498+
list of operand tuple indices representing the operand part. The list of output
2499+
or operand tuple indices is empty if the corresponding type is not a `tuple`
2500+
type, and can be arbitrarily long for an arbitrarily nested tuple type. This
2501+
is similar to [the XLA alias representation](https://www.tensorflow.org/xla/aliasing).
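
To make the index-path wording concrete, here is a minimal C++ sketch of how a list of tuple indices could select a part of a possibly nested tuple type. `ToyType`, `makeLeaf`, `makeTuple`, and `resolvePart` are hypothetical names invented for this illustration; they are not StableHLO or MLIR APIs, and the actual verifier may be structured differently.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a type: either a leaf (tensor, buffer, token)
// or a tuple of nested types. This is not an MLIR class.
struct ToyType {
  std::string name;               // textual form for leaves, "tuple" otherwise
  bool isTuple;
  std::vector<ToyType> elements;  // used only for tuples
};

ToyType makeLeaf(std::string name) {
  return ToyType{std::move(name), false, {}};
}

ToyType makeTuple(std::vector<ToyType> elements) {
  return ToyType{"tuple", true, std::move(elements)};
}

// Follows `tupleIndices` from `type` downwards and returns the selected part.
// Returns nullptr when the path steps into a non-tuple type or uses an index
// that is out of range. An empty path selects `type` itself.
const ToyType* resolvePart(const ToyType& type,
                           const std::vector<int64_t>& tupleIndices) {
  const ToyType* current = &type;
  for (int64_t index : tupleIndices) {
    if (!current->isTuple) return nullptr;
    if (index < 0 ||
        static_cast<std::size_t>(index) >= current->elements.size())
      return nullptr;
    current = &current->elements[static_cast<std::size_t>(index)];
  }
  return current;
}
```

For `tuple<memref<4xf64>, tuple<tensor<i32>>>`, the path `[1, 0]` selects `tensor<i32>`. The empty path selects the whole type, which is the only valid path when the type is not a tuple.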
2502+
2503+
The output part and the input part in an alias pair must have the same type. For
2504+
custom_call ops that aren't calls to `CreateBuffer`, `Pin`, or `Unpin`, a
2505+
`buffer` operand can appear in at most one alias pair, and a `buffer` output
2506+
must appear in one alias pair.
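
The rules in this paragraph can be sketched as a small checker. This is an illustration under stated assumptions, not the verifier added by this commit: it assumes no tuple-typed operands or outputs (so the tuple index lists are always empty and are omitted), it compares types by their textual form, and it reads "must appear in one" as "exactly one". `FlatAlias`, `isBuffer`, `isSpecialTarget`, and `checkAliases` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical flattened alias pair. Tuple index lists are omitted because
// this sketch assumes that no operand or output has a tuple type.
struct FlatAlias {
  std::size_t outputIndex;
  std::size_t operandIndex;
};

// Assumption of this sketch: a type spelled "memref<...>" is a buffer type.
bool isBuffer(const std::string& type) { return type.rfind("memref<", 0) == 0; }

// The three call targets that are exempt from the buffer alias-count rules.
bool isSpecialTarget(const std::string& target) {
  return target == "CreateBuffer" || target == "Pin" || target == "Unpin";
}

// Returns an empty string when all rules hold, otherwise a short description
// of the first violation found.
std::string checkAliases(const std::string& callTarget,
                         const std::vector<std::string>& operandTypes,
                         const std::vector<std::string>& outputTypes,
                         const std::vector<FlatAlias>& aliases) {
  std::vector<int> operandUses(operandTypes.size(), 0);
  std::vector<int> outputUses(outputTypes.size(), 0);
  for (const FlatAlias& alias : aliases) {
    if (alias.operandIndex >= operandTypes.size() ||
        alias.outputIndex >= outputTypes.size())
      return "alias index out of range";
    // The output part and the operand part must have the same type.
    if (operandTypes[alias.operandIndex] != outputTypes[alias.outputIndex])
      return "aliased parts have different types";
    ++operandUses[alias.operandIndex];
    ++outputUses[alias.outputIndex];
  }
  if (isSpecialTarget(callTarget)) return "";
  for (std::size_t i = 0; i < operandTypes.size(); ++i)
    if (isBuffer(operandTypes[i]) && operandUses[i] > 1)
      return "buffer operand appears in more than one alias pair";
  // Interpretation: "must appear in one" is checked as "exactly one".
  for (std::size_t i = 0; i < outputTypes.size(); ++i)
    if (isBuffer(outputTypes[i]) && outputUses[i] != 1)
      return "buffer output must appear in one alias pair";
  return "";
}
```

Under these assumptions, a custom_call that takes and returns `memref<4xf64>` passes only when one alias pair links output 0 to operand 0, while `Pin` may return a buffer without any alias pair.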
2507+
24512508
#### Examples
24522509

24532510
```mlir
@@ -2458,6 +2515,16 @@ the XLA compiler. In the future, we are planning to unify this metadata
24582515
api_version = 4 : i32,
24592516
called_computations = [@foo]
24602517
} : (tensor<f64>) -> tensor<f64>
2518+
2519+
%updated_buffer = "stablehlo.custom_call"(%buffer) {
2520+
call_target_name = "Update",
2521+
api_version = 4 : i32,
2522+
output_operand_aliases = [
2523+
#stablehlo.output_operand_alias<output_tuple_indices = [],
2524+
operand_index = 0,
2525+
operand_tuple_indices = []>]
2526+
} : (memref<4xf64>) -> memref<4xf64>
2527+
24612528
```
24622529

24632530
### divide
@@ -3780,9 +3847,9 @@ Extracts element at `index` position of the `operand` tuple and produces a
37803847

37813848
#### Outputs
37823849

3783-
| Name | Type | Constraints |
3784-
|----------|--------------------|-------------|
3785-
| `result` | any supported type | (C2) |
3850+
| Name | Type | Constraints |
3851+
|----------|------------------------|-------------|
3852+
| `result` | any value | (C2) |
37863853

37873854
#### Constraints
37883855

@@ -6583,10 +6650,10 @@ Produces a `result` tuple from values `val`.
65836650
#### Examples
65846651

65856652
```mlir
6586-
// %val0: [1.0, 2.0]
6653+
// %val0: memref[1.0, 2.0]
65876654
// %val1: (3)
6588-
%result = "stablehlo.tuple"(%val0, %val1) : (tensor<2xf32>, tuple<tensor<i32>>) -> tuple<tensor<2xf32>, tuple<tensor<i32>>>
6589-
// %result: ([1.0, 2.0], (3))
6655+
%result = "stablehlo.tuple"(%val0, %val1) : (memref<2xf32>, tuple<tensor<i32>>) -> tuple<memref<2xf32>, tuple<tensor<i32>>>
6656+
// %result: (memref[1.0, 2.0], (3))
65906657
```
65916658

65926659
&nbsp;[More Examples](https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/tuple_and_get_tuple_element.mlir)
@@ -6692,17 +6759,17 @@ The behavior of an infinite loop is TBD
66926759

66936760
#### Inputs
66946761

6695-
| Label | Name | Type | Constraints |
6696-
|-------|-----------|---------------------------------------------------------|-------------|
6697-
| (I1) | `operand` | variadic number of tensors, quantized tensors or tokens | (C1-C3) |
6698-
| (I2) | `cond` | function | (C1) |
6699-
| (I3) | `body` | function | (C2) |
6762+
| Label | Name | Type | Constraints |
6763+
|-------|-----------|-----------------------------------------|-------------|
6764+
| (I1) | `operand` | variadic number of values | (C1-C3) |
6765+
| (I2) | `cond` | function | (C1) |
6766+
| (I3) | `body` | function | (C2) |
67006767

67016768
#### Outputs
67026769

6703-
| Name | Type | Constraints |
6704-
|-----------|---------------------------------------------------------|-------------|
6705-
| `results` | variadic number of tensors, quantized tensors or tokens | (C3) |
6770+
| Name | Type | Constraints |
6771+
|-----------|-------------------------------------------------|-------------|
6772+
| `results` | variadic number of values | (C3) |
67066773

67076774
#### Constraints
67086775

stablehlo/dialect/Base.cpp

Lines changed: 30 additions & 0 deletions
@@ -798,5 +798,35 @@ bool hasSingleBoundedDimension(Type type) {
798798
return numBoundedDims == 1 && numDynamicDims == 1;
799799
}
800800

801+
//===----------------------------------------------------------------------===//
802+
// Utils for type traversal.
803+
//===----------------------------------------------------------------------===//
804+
805+
namespace {
806+
LogicalResult mapOverLeafTypesImpl(
807+
Type type, function_ref<LogicalResult(Type type, ArrayRef<int64_t>)> fn,
808+
std::vector<int64_t>& indices) {
809+
if (!isa<TupleType>(type)) {
810+
return fn(type, indices);
811+
}
812+
813+
auto tupleType = cast<TupleType>(type);
814+
for (size_t i = 0; i < tupleType.size(); ++i) {
815+
indices.push_back(i);
816+
if (failed(mapOverLeafTypesImpl(tupleType.getType(i), fn, indices)))
817+
return failure();
818+
indices.pop_back();
819+
}
820+
821+
return success();
822+
}
823+
} // namespace
824+
825+
LogicalResult mapOverLeafTypes(
826+
Type type, function_ref<LogicalResult(Type, ArrayRef<int64_t>)> fn) {
827+
std::vector<int64_t> indices;
828+
return mapOverLeafTypesImpl(type, fn, indices);
829+
}
830+
801831
} // namespace hlo
802832
} // namespace mlir
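
The traversal order and index paths produced by `mapOverLeafTypes` can be illustrated without an MLIR build. The sketch below mirrors the same push/recurse/pop pattern on a toy type tree. `ToyType`, `makeLeaf`, `makeTuple`, `visitLeaves`, and `describeLeaves` are hypothetical stand-ins, and `bool` replaces `LogicalResult` (`true` meaning success).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::Type: a leaf or a tuple of nested types.
struct ToyType {
  std::string name;               // textual form for leaves, "tuple" otherwise
  bool isTuple;
  std::vector<ToyType> elements;  // used only for tuples
};

ToyType makeLeaf(std::string name) {
  return ToyType{std::move(name), false, {}};
}

ToyType makeTuple(std::vector<ToyType> elements) {
  return ToyType{"tuple", true, std::move(elements)};
}

using LeafFn =
    std::function<bool(const ToyType&, const std::vector<int64_t>&)>;

// Depth-first walk: `indices` holds the path from the root to the current
// type, and the walk stops as soon as `fn` returns false.
bool visitLeavesImpl(const ToyType& type, const LeafFn& fn,
                     std::vector<int64_t>& indices) {
  if (!type.isTuple) return fn(type, indices);
  for (std::size_t i = 0; i < type.elements.size(); ++i) {
    indices.push_back(static_cast<int64_t>(i));
    if (!visitLeavesImpl(type.elements[i], fn, indices)) return false;
    indices.pop_back();
  }
  return true;
}

bool visitLeaves(const ToyType& type, const LeafFn& fn) {
  std::vector<int64_t> indices;
  return visitLeavesImpl(type, fn, indices);
}

// Renders every leaf as "<name> @ [i,j,...]" in visiting order.
std::vector<std::string> describeLeaves(const ToyType& type) {
  std::vector<std::string> out;
  visitLeaves(type, [&](const ToyType& leafType,
                        const std::vector<int64_t>& path) {
    std::string text = leafType.name + " @ [";
    for (std::size_t i = 0; i < path.size(); ++i) {
      if (i > 0) text += ",";
      text += std::to_string(path[i]);
    }
    out.push_back(text + "]");
    return true;
  });
  return out;
}
```

In this sketch a non-tuple type is visited once with an empty path, and an empty tuple yields no callbacks at all, matching the control flow of the helper above.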

stablehlo/dialect/Base.h

Lines changed: 5 additions & 0 deletions
@@ -314,6 +314,11 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
314314
mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op,
315315
int64_t shapeCount);
316316

317+
// Applies `fn` to `type` if it is not a `tuple` type. Otherwise, applies `fn`
318+
// to each leaf type in the `tuple` type tree or until a `fn` returns failure.
319+
LogicalResult mapOverLeafTypes(
320+
Type type, function_ref<LogicalResult(Type, ArrayRef<int64_t>)> fn);
321+
317322
namespace OpTrait {
318323

319324
template <typename ConcreteType>

stablehlo/dialect/Base.td

Lines changed: 9 additions & 3 deletions
@@ -193,16 +193,22 @@ def HLO_FloatOrQuantizedIntOrPerAxisQuantizedIntTensor : RankedTensorOf<[HLO_Flo
193193

194194
def HLO_ComplexTensor : RankedTensorOf<[HLO_Complex]>;
195195

196-
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_PerAxisQuantizedIntTensor, HLO_Token]>;
196+
def HLO_Buffer : MemRefOf<[HLO_Float, HLO_Pred, HLO_Int, HLO_Complex, HLO_QuantizedInt]>;
197+
198+
def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Buffer, HLO_PerAxisQuantizedIntTensor, HLO_Token]>;
197199

198200
def HLO_TensorOrToken : AnyTypeOf<[HLO_Tensor, HLO_Token]>;
199201

200202
def HLO_TensorOrPerAxisQuantizedTensorOrToken : AnyTypeOf<[HLO_Tensor, HLO_PerAxisQuantizedIntTensor, HLO_Token]>;
201203

204+
def HLO_TensorOrPerAxisQuantizedTensorOrTokenOrBuffer : AnyTypeOf<[HLO_TensorOrPerAxisQuantizedTensorOrToken, HLO_Buffer]>;
205+
202206
def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;
203207

204208
def HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_PerAxisQuantizedIntTensor, HLO_Token, HLO_Tuple]>;
205209

210+
def HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTupleOrBuffer : AnyTypeOf<[HLO_TensorOrPerAxisQuantizedTensorOrTokenOrTuple, HLO_Buffer]>;
211+
206212
def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Int]>;
207213

208214
// Dynamic representation of a shape vector as a tensor.
@@ -227,9 +233,9 @@ def HLO_AnyFpOrComplexTensor : TensorOf<[HLO_Float, HLO_Complex]>;
227233

228234
def HLO_AnyPredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;
229235

230-
def HLO_AnyTuple : NestedTupleOf<[HLO_AnyTensor, HLO_Token]>;
236+
def HLO_CustomCallTuple : NestedTupleOf<[HLO_AnyTensor, HLO_Buffer, HLO_Token]>;
231237

232-
def HLO_CustomCallValue : AnyTypeOf<[HLO_AnyTensor, HLO_Token, HLO_AnyTuple]>;
238+
def HLO_CustomCallValue : AnyTypeOf<[HLO_AnyTensor, HLO_Buffer, HLO_Token, HLO_CustomCallTuple]>;
233239

234240
//===----------------------------------------------------------------------===//
235241
// HLO combined type definitions.
