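"""Experimental Gluon tensor-descriptor and FMA-dot helpers for the Intel (XPU) backend.

Exposes ``make_tensor_descriptor`` and ``dot_fma`` (see ``__all__``), built on
top of ``IntelDPASLayout``.
"""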
from __future__ import annotations

from typing import List, Tuple, Sequence
from dataclasses import dataclass
import triton.language.core as tl_core

import triton.experimental.gluon.language._core as ttgl
from triton.experimental.gluon.language._layouts import DotOperandLayout
from triton.experimental.gluon.language.intel._layouts import IntelDPASLayout
from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
from triton.language.core import ir, constexpr, tensor_descriptor_base, block_type, tensor, tuple


# load_tensor_descriptor = builtin(tl_core.load_tensor_descriptor)
# store_tensor_descriptor = builtin(tl_core.store_tensor_descriptor)


__all__ = ["make_tensor_descriptor", "dot_fma"]


class tensor_descriptor(tensor_descriptor_base):
    """A descriptor representing a tensor in global memory."""

    def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type, layout):
        """Not called by user code."""
        # IR handle
        super().__init__(handle, block_type)
        # Global shape and strides of the underlying tensor
        self.shape = tuple(shape)
        self.strides = tuple(strides)
        self.layout = layout

        self.type = tensor_descriptor_type(
            block_type,
            shape_type=self.shape.type,
            strides_type=self.strides.type,
            layout=self.layout,
        )

    def _flatten_ir(self, handles: List[ir.value]) -> None:
        handles.append(self.handle)
        self.shape._flatten_ir(handles)
        self.strides._flatten_ir(handles)

    # TODO: see MaterializeBlockPointers.cpp:
    # add a 2d_block_io parameter plus validation before setting the attribute,
    # e.g. > 2 dims, strides 16-byte aligned, and others.
    @builtin
    def load(self, offsets: Sequence[constexpr | tensor], is_2d_block=False, _semantic=None) -> tensor:
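        """Load a block of data from global memory at the given offsets.

        A minimal usage sketch (assuming ``desc`` was built with
        ``make_tensor_descriptor``; ``off_m``/``off_n`` are illustrative names):

            tile = desc.load([off_m, off_n], is_2d_block=True)

        Passing ``is_2d_block=True`` tags the op with the
        ``ttig.block_io = "row_major"`` attribute.
        """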
        # Empty strings select the default cache modifier and eviction policy.
        op = _semantic.descriptor_load(self, offsets, "", "")

        if is_2d_block:
            # TODO: proper handling, e.g. an option to select row- or
            # column-major and the other block-IO parameters.
            attr = _semantic.builder.get_string_attr("row_major")
            op.handle.set_attr("ttig.block_io", attr)

        return op

    @builtin
    def store(self, offsets: Sequence[constexpr | tensor], value: tensor, is_2d_block=False, _semantic=None) -> tensor:
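        """Store ``value`` to global memory at the given offsets.

        A minimal usage sketch (assuming ``desc`` and ``tile`` as in ``load``):

            desc.store([off_m, off_n], tile, is_2d_block=True)

        As with ``load``, ``is_2d_block=True`` tags the op with the
        ``ttig.block_io = "row_major"`` attribute.
        """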
        op = _semantic.descriptor_store(self, value, offsets)

        if is_2d_block:
            attr = _semantic.builder.get_string_attr("row_major")
            op.handle.set_attr("ttig.block_io", attr)

        return op

    @builtin
    def prefetch(self, offsets: Sequence[constexpr | tensor], mask=None, cache=None, evict=None, is_volatile=False, is_2d_block=False, _semantic=None):
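        """Prefetch a block of data at the given offsets (work in progress).

        ``mask``, ``cache``, ``evict`` and ``is_volatile`` are accepted but not
        yet wired through; see the TODOs below.
        """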
        # TODO: handle the remaining ttig.prefetch parameters:
        # - passing the raw pointer handle is temporary; proper tensor
        #   descriptor support is needed
        # - calculate offsets the way tt.advance does
        # - maybe add support for mask (it appears to be optional)
        # - also the 2D-block attribute and others
        # return _semantic.builder.create_prefetch(ptr.handle, False)

        # Reference: python/triton/language/semantic.py @ load:1077 (TritonSemantic):
        #   cache_modifier: str, eviction_policy: str
        #   cache = self._str_to_load_cache_modifier(cache_modifier)
        #   eviction = self._str_to_eviction_policy(eviction_policy)

        ptr_handle = self.handle
        # Unwrap tensor offsets to raw IR handles; plain values pass through.
        offsets_handles = [offset.handle if hasattr(offset, "handle") else offset for offset in offsets]
        op = _semantic.builder.create_prefetch(ptr_handle, offsets_handles, False)

        if is_2d_block:
            attr = _semantic.builder.get_string_attr("row_major")
            op.set_attr("ttig.block_io", attr)

        return op


@dataclass(eq=True)
class tensor_descriptor_type(ttgl.base_type):
    """The type for a tensor descriptor."""

    block_type: ttgl.block_type
    shape_type: ttgl.tuple_type
    strides_type: ttgl.tuple_type
    layout: IntelDPASLayout

    def __str__(self) -> str:
        return f"tensor_descriptor<{self.block_type}, {self.layout}>"

    def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]:
        handle = handles[cursor]
        cursor += 1
        shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
        strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
        value = tensor_descriptor(handle, shape, strides, self.block_type, self.layout)
        return value, cursor

    def _to_ir(self, builder: ir.builder) -> ir.type:
        is_signed = self.block_type.element_ty.is_int_signed()
        return builder.get_tensor_descriptor_layout_type(
            self.block_type.to_ir(builder),
            is_signed,
            self.layout._to_ir(builder),
        )

    def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
        out.append(self._to_ir(builder))
        self.shape_type._flatten_ir_types(builder, out)
        self.strides_type._flatten_ir_types(builder, out)

    def mangle(self) -> str:
        return f"TD{self.block_type.mangle()}_{self.shape_type.mangle()}_{self.strides_type.mangle()}_{self.layout.mangle()}TD"


@builtin
def make_tensor_descriptor(ptr: ttgl.tensor, shape: List[int], strides: List[int],
                           block_shape: List[int], layout: IntelDPASLayout,
                           _semantic=None) -> tensor_descriptor:
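    """Create a descriptor for a tensor in global memory.

    A minimal usage sketch (assuming a row-major M x N tensor and an existing
    ``IntelDPASLayout`` instance named ``dpas_layout``; the names are
    illustrative only):

        desc = make_tensor_descriptor(ptr, shape=[M, N], strides=[N, 1],
                                      block_shape=[BLOCK_M, BLOCK_N],
                                      layout=dpas_layout)
    """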
    # Unwrap constexpr if needed
    layout = _unwrap_if_constexpr(layout)

    # Get the pointer handle directly
    ptr_handle = ptr.handle

    # Convert shape and strides to IR values and create tensor objects
    shape_handles = _semantic._convert_to_ir_values(shape, require_i64=False)
    stride_handles = _semantic._convert_to_ir_values(strides, require_i64=True)

    # Create tensor objects from the handles
    shape_tensors = [ttgl.tensor(h, ttgl.int32) for h in shape_handles]
    stride_tensors = [ttgl.tensor(h, ttgl.int64) for h in stride_handles]

    # Build type information
    block_type = ttgl.block_type(ptr.type.element_ty, block_shape)

    # TODO: this is a workaround for the xpu_dot_fma assertion; layout support
    # for block_type is not implemented yet. See: gluon/language/_core.py:19
    block_type.layout = layout

    shape_type = ttgl.tuple_type([ttgl.int32] * len(shape))
    strides_type = ttgl.tuple_type([ttgl.int64] * len(strides))

    # Pass tensor objects, not constexpr values
    shape_tuple = ttgl.tuple(shape_tensors, shape_type)
    strides_tuple = ttgl.tuple(stride_tensors, strides_type)

    desc_type = tensor_descriptor_type(block_type, shape_type, strides_type, layout)

    # Create the descriptor
    padding = _semantic._str_to_padding_option("zero")
    desc_handle = _semantic.builder.create_make_tensor_descriptor(
        desc_type._to_ir(_semantic.builder),
        ptr_handle,
        shape_handles,
        stride_handles,
        padding,
    )

    return tensor_descriptor(desc_handle, shape_tuple, strides_tuple, block_type, layout)


@builtin
def dot_fma(a, b, acc, _semantic=None):
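    """Compute ``a @ b + acc`` via ``_semantic.dot`` and return the result with
    ``acc``'s type.

    ``acc`` must carry an ``IntelDPASLayout``, and ``a``/``b`` must carry
    ``DotOperandLayout``s (operand indices 0 and 1) whose parent is that
    layout. A minimal call sketch:

        c = dot_fma(a_tile, b_tile, acc)
    """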
    assert isinstance(a, tensor), "a must be a tensor"
    assert isinstance(b, tensor), "b must be a tensor"
    assert isinstance(acc, tensor), "acc must be a tensor"

    mma_layout = acc.type.layout
    assert isinstance(mma_layout, IntelDPASLayout), "acc must have an IntelDPASLayout"
    assert isinstance(a.type.layout, DotOperandLayout), "a must have a DotOperandLayout"
    assert isinstance(b.type.layout, DotOperandLayout), "b must have a DotOperandLayout"
    assert a.type.layout.parent == mma_layout, "a's parent layout must be the same as acc's layout"
    assert b.type.layout.parent == mma_layout, "b's parent layout must be the same as acc's layout"
    assert a.type.layout.operand_index == 0, "a's operand index must be 0"
    assert b.type.layout.operand_index == 1, "b's operand index must be 1"

    handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle
    return tensor(handle, acc.type)