@@ -131,62 +131,51 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
131131 return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
132132 }
133133
134- BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
135- return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
136- }
137-
138- ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr () const {
139- return llvm::dyn_cast_if_present<ScatterTensorDescAttr >(getEncoding());
134+ template <typename T,
135+ typename = std::enable_if_t<
136+ std::is_same_v<T, BlockTensorDescAttr> ||
137+ std::is_same_v<T, ScatterTensorDescAttr>>>
138+ T getEncodingOfType () const {
139+ return llvm::dyn_cast_if_present<T >(getEncoding());
140140 }
141141
142142 LayoutAttr getLayoutAttr() const {
143143 return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
144144 }
145145
146146 xegpu::MemorySpace getMemorySpace() const {
147- auto block_attr = getEncodingAsBlockTensorDescAttr();
148- if (block_attr && block_attr.getMemorySpace())
149- return block_attr.getMemorySpace().getValue();
147+ if (auto attr = getEncodingOfType<BlockTensorDescAttr>())
148+ return attr.getMemorySpace().getValue();
150149
151- auto scatter_attr = getEncodingAsScatterTensorDescAttr();
152- if (scatter_attr && scatter_attr.getMemorySpace())
153- return scatter_attr.getMemorySpace().getValue();
150+ if (auto attr = getEncodingOfType<ScatterTensorDescAttr>())
151+ return attr.getMemorySpace().getValue();
154152
155- // return default value
153+ llvm_unreachable("invalid encoding");
156154 return MemorySpace::Global;
157155 }
158156
159157 // get the ArrayLength for blocked TensorDesc
160158 int getArrayLength() {
161- auto attr = getEncoding();
162- auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
163- assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
164- if (block_attr && block_attr.getArrayLength())
165- return block_attr.getArrayLength().getInt();
166- // return default value
167- return 1;
159+ auto attr = getEncodingOfType<BlockTensorDescAttr>();
160+ assert(attr && "invalid on non BlockTensorDescAttr.");
161+ return attr.getArrayLength().getInt();
168162 }
169163
170164 bool getBoundaryCheck() {
171- auto attr = getEncoding();
172- auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
173- assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
174- if (block_attr && block_attr.getBoundaryCheck())
175- return block_attr.getBoundaryCheck().getValue();
176- // return default value
177- return true;
165+ auto attr = getEncodingOfType<BlockTensorDescAttr>();
166+ assert(attr && "invalid on non BlockTensorDescAttr.");
167+ return attr.getBoundaryCheck().getValue();
178168 }
179169
180170 bool isScattered() {
181- return bool(getEncodingAsScatterTensorDescAttr ());
171+ return bool(getEncodingOfType<ScatterTensorDescAttr> ());
182172 }
183173
184174 // get the ChunkSize for scattered TensorDesc
185175 int getChunkSizeAsInt() {
186- auto attr = getEncoding();
187- auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
188- assert(scatter_attr && "invalid on non ScatterTensorDescAttr.");
189- return scatter_attr.getChunkSizeAsInt();
176+ auto attr = getEncodingOfType<ScatterTensorDescAttr>();
177+ assert(attr && "invalid on non ScatterTensorDescAttr.");
178+ return attr.getChunkSizeAsInt();
190179 }
191180
192181 /// Helper to drop all layout information from the TensorDesc type.
0 commit comments