| 
13 | 13 | namespace mlir {  | 
14 | 14 | 
 
  | 
15 | 15 | class VectorType;  | 
 | 16 | +class OpOperand;  | 
 | 17 | +class OpResult;  | 
 | 18 | +class OpBuilder;  | 
 | 19 | +class ValueRange;  | 
 | 20 | +class TypeConverter;  | 
 | 21 | + | 
16 | 22 | namespace xegpu {  | 
17 | 23 | class LayoutAttr;  | 
18 | 24 | class TensorDescType;  | 
@@ -50,6 +56,59 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);  | 
50 | 56 | FailureOr<VectorType> getDistributedVectorType(VectorType originalType,  | 
51 | 57 |                                                LayoutAttr layout);  | 
52 | 58 | 
 
  | 
 | 59 | +/// Return the attribute name for the OpOperand to attach LayoutAttr  | 
 | 60 | +std::string getLayoutName(const OpOperand &operand);  | 
 | 61 | + | 
 | 62 | +/// Return the attribute name for the OpResult to attach LayoutAttr  | 
 | 63 | +std::string getLayoutName(const OpResult result);  | 
 | 64 | + | 
 | 65 | +/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType  | 
 | 66 | +/// values, the LayoutAttr is extracted from the TensorDescType itself. For  | 
 | 67 | +/// other values, it is obtained from the attributes of the defining operation.  | 
 | 68 | +/// Returns nullptr if no LayoutAttr is found.  | 
 | 69 | +LayoutAttr getLayoutAttr(const Value value);  | 
 | 70 | + | 
 | 71 | +/// Retrieves the LayoutAttr associated with a given OpOperand. It will  | 
 | 72 | +/// first check the operand_layout_{id} of the owner operation. If not found,  | 
 | 73 | +/// it will check the operand itself and its defining op.  | 
 | 74 | +LayoutAttr getLayoutAttr(const OpOperand &opr);  | 
 | 75 | + | 
 | 76 | +/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching  | 
 | 77 | +/// it to the owner's dictionary attributes  | 
 | 78 | +template <typename T,  | 
 | 79 | +          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||  | 
 | 80 | +                                      std::is_same_v<T, OpResult>>>  | 
 | 81 | +void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout);  | 
 | 82 | + | 
 | 83 | +/// Set the LayoutAttr for each OpOperand and OpResult of the given operation.  | 
 | 84 | +/// If the operation contains regions, it is also applied recursively to the  | 
 | 85 | +/// contained operations  | 
 | 86 | +void setLayoutAttrs(Operation *op,  | 
 | 87 | +                    function_ref<LayoutAttr(Value)> getLayoutImpl);  | 
 | 88 | + | 
 | 89 | +/// Extract a set of small vectors from a value with a given shape using  | 
 | 90 | +/// vector.extract_stride_slice  | 
 | 91 | +SmallVector<Value> extractVectorsWithShapeFromValue(OpBuilder &builder,  | 
 | 92 | +                                                    Location loc, Value value,  | 
 | 93 | +                                                    ArrayRef<int64_t> shape);  | 
 | 94 | + | 
 | 95 | +/// Create a vector of shape from a set of values using  | 
 | 96 | +/// vector.insert_stride_slice.  | 
 | 97 | +Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,  | 
 | 98 | +                                      ValueRange values,  | 
 | 99 | +                                      ArrayRef<int64_t> shape);  | 
 | 100 | + | 
 | 101 | +/// Do type conversion for SCF structural ops, e.g., scf.for using SCF structure  | 
 | 102 | +/// type convertion patterns. Since VectorType cannot carry the layout  | 
 | 103 | +/// attribute, which is needed to guide the type conversion for XeGPU, they are  | 
 | 104 | +/// first converted into RankedTensorType, where the layout attribute can be  | 
 | 105 | +/// attached. And then upstream SCF structural type conversion patterns are  | 
 | 106 | +/// applied with the provided converter.  | 
 | 107 | +/// TODO: This is a temporary solution. We should refactor it when context-aware  | 
 | 108 | +/// type conversion is available.  | 
 | 109 | +void doSCFStructuralTypeConversionWithTensorType(Operation *op,  | 
 | 110 | +                                                 TypeConverter converter);  | 
 | 111 | + | 
53 | 112 | } // namespace xegpu  | 
54 | 113 | 
 
  | 
55 | 114 | } // namespace mlir  | 
 | 
0 commit comments