@@ -89,7 +89,7 @@ completely numeric to simplify generation of StableHLO programs.
8989
9090``` ebnf
9191Type ::= ValueType | NonValueType
92- ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType
92+ ValueType ::= TensorType | QuantizedTensorType | TokenType | TupleType | BufferType
9393NonValueType ::= TensorElementType | QuantizedTensorElementType | FunctionType | StringType
9494```
9595
@@ -229,6 +229,21 @@ TupleType ::= 'tuple' '<' TupleElementTypes '>'
229229TupleElementTypes ::= [ValueType {',' ValueType}]
230230```
231231
232+ ** Buffer types** represent buffers. For example, in XLA, buffers are
233+ multidimensional arrays with consistent storage. Similar to ** tensor types** ,
234+ buffer types have a ** shape** and an ** element type** , where a shape represents
235+ non-negative or unknown ** dimension sizes** in the ascending order of the
236+ corresponding ** dimensions** (which are also called ** axes** ) numbered from ` 0 `
237+ to ` R-1 ` . The number of dimensions ` R ` is called ** rank** . For example,
238+ ` memref<2x3xf32> ` is a buffer type with shape ` 2x3 ` and element type ` f32 ` . It
239+ has two dimensions (or, in other words, two axes) - 0th dimension and 1st
240+ dimension - whose sizes are 2 and 3. Its rank is 2.
241+
242+ Buffers can be allocated using a ` custom_call ` to ` CreateBuffer ` or ` Pin ` and
243+ deallocated via a ` custom_call ` to ` Unpin ` . Only ` custom_call ` ops can read and
244+ write the content inside buffers. See [ custom_call] ( #custom_call ) for more
245+ detail.
246+
232247** Tuple types** represent tuples, i.e. heterogeneous lists. Tuples are a legacy
233248feature which only exists for compatibility with HLO. In HLO, tuples are
234249used to represent variadic inputs and outputs. In StableHLO, variadic inputs and
@@ -2433,21 +2448,63 @@ the XLA compiler. In the future, we are planning to unify this metadata
24332448
24342449#### Inputs
24352450
2436- | Label | Name | Type |
2437- | -------| -----------------------| ---------------------------------------------------|
2438- | (I1) | ` inputs ` | variadic number of values |
2439- | (I2) | ` call_target_name ` | constant of type ` string ` |
2440- | (I3) | ` has_side_effect ` | constant of type ` i1 ` |
2441- | (I4) | ` backend_config ` | constant of type ` string ` or attribute dictionary |
2442- | (I5) | ` api_version ` | constant of type ` si32 ` |
2443- | (I6) | ` called_computations ` | variadic number of constants of type ` string ` |
2451+ | Label | Name | Type |
2452+ | -------| --------------------------| ------------------------------------------------------------|
2453+ | (I1) | ` inputs ` | variadic number of values |
2454+ | (I2) | ` call_target_name ` | constant of type ` string ` |
2455+ | (I3) | ` has_side_effect ` | constant of type ` i1 ` |
2456+ | (I4) | ` backend_config ` | constant of type ` string ` or attribute dictionary |
2457+ | (I5) | ` api_version ` | constant of type ` si32 ` |
2458+ | (I6) | ` called_computations ` | variadic number of constants of type ` string ` |
2459+ | (I7) | ` output_operand_aliases ` | specify the aliasing parts in the outputs and operands |
24442460
24452461#### Outputs
24462462
24472463| Name | Type |
24482464| -----------| ---------------------------|
24492465| ` results ` | variadic number of values |
24502466
2467+ ### (XLA GPU Support) Special custom_call targets
2468+
2469+ There are three special ` call_target_name ` related to ` buffer ` types:
2470+ ` CreateBuffer ` creates an uninitialized ` buffer ` , ` Pin ` creates an initialized
2471+ ` buffer ` and ` Unpin ` deallocates a ` buffer ` and returns the content of the
2472+ ` buffer ` .
2473+
2474+ ``` mlir
2475+ %uninitialized_buffer = "stablehlo.custom_call"() {
2476+ call_target_name = "CreateBuffer",
2477+ api_version = 4 : i32,
2478+ } : () -> memref<4xf64>
2479+
2480+ %initialized_buffer = "stablehlo.custom_call"(%init_value) {
2481+ call_target_name = "Pin",
2482+ api_version = 4 : i32,
2483+ } : (tensor<4xf64>) -> memref<4xf64>
2484+
2485+ %dealloc_buffer = "stablehlo.custom_call"(%initialized_buffer) {
2486+ call_target_name = "Unpin",
2487+ api_version = 4 : i32,
2488+ } : (memref<4xf64>) -> tensor<4xf64>
2489+
2490+ ```
2491+
2492+ ### Alias
2493+
2494+ Some custom_call ops may require a part in the outputs and a part in the
2495+ operands to share the same memory. This can be expressed via
2496+ ` output_operand_aliases ` . An alias pair representation consists a list of output
2497+ tuple indices representing the output part, and an operand_index along with a
2498+ list of operand tuple indices representing the operand part. The list of output
2499+ or operand tuple indices is empty if the corresponding type is not a ` tuple `
2500+ type, and can be arbitrarily long for an arbitrarily nested tuple type. This
2501+ is similar to [ the XLA alias representation] ( https://www.tensorflow.org/xla/aliasing ) .
2502+
2503+ The output part and the input part in an alias pair must have the same type. For
2504+ custom_call ops that aren't call to ` CreateBuffer ` , ` Pin ` and ` Unpin ` , a
2505+ ` buffer ` operand can appear in at most one pair of alias, and a ` buffer ` output
2506+ must appear in one pair of alias.
2507+
24512508#### Examples
24522509
24532510``` mlir
@@ -2458,6 +2515,16 @@ the XLA compiler. In the future, we are planning to unify this metadata
24582515 api_version = 4 : i32,
24592516 called_computations = [@foo]
24602517} : (tensor<f64>) -> tensor<f64>
2518+
2519+ %updated_buffer = "stablehlo.custom_call"(%buffer) {
2520+ call_target_name = "Update",
2521+ api_version = 4 : i32,
2522+ output_operand_aliases = [
2523+ #stablehlo.output_operand_alias<output_tuple_indices = [],
2524+ operand_index = 0,
2525+ operand_tuple_indices = []>]
2526+ } : (memref<4xf64>) -> memref<4xf64>
2527+
24612528```
24622529
24632530### divide
@@ -3780,9 +3847,9 @@ Extracts element at `index` position of the `operand` tuple and produces a
37803847
37813848#### Outputs
37823849
3783- | Name | Type | Constraints |
3784- | ----------| --------------------| -------------|
3785- | ` result ` | any supported type | (C2) |
3850+ | Name | Type | Constraints |
3851+ | ----------| ------------------------ | -------------|
3852+ | ` result ` | any value | (C2) |
37863853
37873854#### Constraints
37883855
@@ -6583,10 +6650,10 @@ Produces a `result` tuple from values `val`.
65836650#### Examples
65846651
65856652``` mlir
6586- // %val0: [1.0, 2.0]
6653+ // %val0: memref [1.0, 2.0]
65876654// %val1: (3)
6588- %result = "stablehlo.tuple"(%val0, %val1) : (tensor <2xf32>, tuple<tensor<i32>>) -> tuple<tensor <2xf32>, tuple<tensor<i32>>>
6589- // %result: ([1.0, 2.0], (3))
6655+ %result = "stablehlo.tuple"(%val0, %val1) : (memref <2xf32>, tuple<tensor<i32>>) -> tuple<memref <2xf32>, tuple<tensor<i32>>>
6656+ // %result: (memref [1.0, 2.0], (3))
65906657```
65916658
65926659  ; [ More Examples] ( https://github.com/openxla/stablehlo/tree/main/stablehlo/tests/interpret/tuple_and_get_tuple_element.mlir )
@@ -6692,17 +6759,17 @@ The behavior of an infinite loop is TBD
66926759
66936760#### Inputs
66946761
6695- | Label | Name | Type | Constraints |
6696- | -------| -----------| --------------------------------------------------------- | -------------|
6697- | (I1) | ` operand ` | variadic number of tensors, quantized tensors or tokens | (C1-C3) |
6698- | (I2) | ` cond ` | function | (C1) |
6699- | (I3) | ` body ` | function | (C2) |
6762+ | Label | Name | Type | Constraints |
6763+ | -------| -----------| -----------------------------------------| -------------|
6764+ | (I1) | ` operand ` | variadic number of values | (C1-C3) |
6765+ | (I2) | ` cond ` | function | (C1) |
6766+ | (I3) | ` body ` | function | (C2) |
67006767
67016768#### Outputs
67026769
6703- | Name | Type | Constraints |
6704- | -----------| --------------------------------------------------------- | -------------|
6705- | ` results ` | variadic number of tensors, quantized tensors or tokens | (C3) |
6770+ | Name | Type | Constraints |
6771+ | -----------| -------------------------------------------------| -------------|
6772+ | ` results ` | variadic number of values | (C3) |
67066773
67076774#### Constraints
67086775
0 commit comments