@@ -131,62 +131,48 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
131
131
return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
132
132
}
133
133
134
- BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
135
- return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
136
- }
137
-
138
- ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr () const {
139
- return llvm::dyn_cast_if_present<ScatterTensorDescAttr >(getEncoding());
134
+ template <typename T,
135
+ typename = std::enable_if_t<
136
+ std::is_same_v<T, BlockTensorDescAttr> ||
137
+ std::is_same_v<T, ScatterTensorDescAttr>>>
138
+ T getEncodingOfType () const {
139
+ return llvm::dyn_cast_if_present<T >(getEncoding());
140
140
}
141
141
142
142
LayoutAttr getLayoutAttr() const {
143
143
return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
144
144
}
145
145
146
146
xegpu::MemorySpace getMemorySpace() const {
147
- auto block_attr = getEncodingAsBlockTensorDescAttr();
148
- if (block_attr && block_attr.getMemorySpace())
149
- return block_attr.getMemorySpace().getValue();
150
-
151
- auto scatter_attr = getEncodingAsScatterTensorDescAttr();
152
- if (scatter_attr && scatter_attr.getMemorySpace())
153
- return scatter_attr.getMemorySpace().getValue();
147
+ if (auto attr = getEncodingOfType<BlockTensorDescAttr>())
148
+ return attr.getMemorySpace().getValue();
154
149
155
- // return default value
156
- return MemorySpace::Global ;
150
+ auto attr = getEncodingOfType<ScatterTensorDescAttr>();
151
+ return attr.getMemorySpace().getValue() ;
157
152
}
158
153
159
154
// get the ArrayLength for blocked TensorDesc
160
155
int getArrayLength() {
161
- auto attr = getEncoding();
162
- auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
163
- assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
164
- if (block_attr && block_attr.getArrayLength())
165
- return block_attr.getArrayLength().getInt();
166
- // return default value
167
- return 1;
156
+ auto attr = getEncodingOfType<BlockTensorDescAttr>();
157
+ assert(attr && "invalid on non BlockTensorDescAttr.");
158
+ return attr.getArrayLength().getInt();
168
159
}
169
160
170
161
bool getBoundaryCheck() {
171
- auto attr = getEncoding();
172
- auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
173
- assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
174
- if (block_attr && block_attr.getBoundaryCheck())
175
- return block_attr.getBoundaryCheck().getValue();
176
- // return default value
177
- return true;
162
+ auto attr = getEncodingOfType<BlockTensorDescAttr>();
163
+ assert(attr && "invalid on non BlockTensorDescAttr.");
164
+ return attr.getBoundaryCheck().getValue();
178
165
}
179
166
180
167
bool isScattered() {
181
- return bool(getEncodingAsScatterTensorDescAttr ());
168
+ return bool(getEncodingOfType<ScatterTensorDescAttr> ());
182
169
}
183
170
184
171
// get the ChunkSize for scattered TensorDesc
185
172
int getChunkSizeAsInt() {
186
- auto attr = getEncoding();
187
- auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
188
- assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
189
- return scatter_attr.getChunkSizeAsInt();
173
+ auto attr = getEncodingOfType<ScatterTensorDescAttr>();
174
+ assert(attr && "invalid on non ScatterTensorDescAttr.");
175
+ return attr.getChunkSizeAsInt();
190
176
}
191
177
192
178
/// Helper to drop all layout information from the TensorDesc type.
0 commit comments