1818#include " mlir/Dialect/Utils/IndexingUtils.h"
1919#include " mlir/IR/BuiltinAttributes.h"
2020#include " mlir/IR/BuiltinTypes.h"
21+ #include " mlir/IR/DialectResourceBlobManager.h"
2122#include " mlir/IR/Matchers.h"
2223#include " mlir/Pass/Pass.h"
2324#include " llvm/ADT/APFloat.h"
@@ -176,13 +177,36 @@ DenseElementsAttr transposeType(const RangeType &data, ShapedType inputType,
176177 llvm::ArrayRef<ElementType>(outputValues));
177178}
178179
180+ // Try to get the values of a DenseResourceElementsAttr construct
181+ template <typename T>
182+ std::optional<ArrayRef<T>> tryGetDenseResourceValues (ElementsAttr attr) {
183+ if (auto denseResource = dyn_cast<DenseResourceElementsAttr>(attr)) {
184+ // Check that the resource memory blob exists
185+ AsmResourceBlob *blob = denseResource.getRawHandle ().getBlob ();
186+ if (!blob)
187+ return std::nullopt ;
188+
189+ // Check that the data are in a valid form
190+ bool isSplat = false ;
191+ if (!DenseElementsAttr::isValidRawBuffer (attr.getShapedType (),
192+ blob->getData (), isSplat)) {
193+ return std::nullopt ;
194+ }
195+
196+ return blob->template getDataAs <T>();
197+ }
198+
199+ return std::nullopt ;
200+ }
201+
179202// A type specialized transposition of an ElementsAttr.
180203// This implementation tries to operate on the underlying data in its raw
181204// representation when possible to avoid allocating a large number of Attribute
182205// objects.
183206DenseElementsAttr transpose (ElementsAttr attr, ShapedType inputType,
184207 ShapedType outputType,
185208 llvm::ArrayRef<int64_t > permValues) {
209+ // Handle generic ElementsAttr
186210 if (auto data = attr.tryGetValues <bool >())
187211 return transposeType (*data, inputType, outputType, permValues);
188212
@@ -204,6 +228,35 @@ DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
204228 if (auto data = attr.tryGetValues <APFloat>())
205229 return transposeType (*data, inputType, outputType, permValues);
206230
231+ // Handle DenseResourceElementsAttr
232+ if (isa<DenseResourceElementsAttr>(attr)) {
233+ auto elementTy = attr.getElementType ();
234+
235+ if (auto data = tryGetDenseResourceValues<bool >(attr);
236+ data && elementTy.isInteger (1 ))
237+ return transposeType (*data, inputType, outputType, permValues);
238+
239+ if (auto data = tryGetDenseResourceValues<int8_t >(attr);
240+ data && elementTy.isInteger (8 ))
241+ return transposeType (*data, inputType, outputType, permValues);
242+
243+ if (auto data = tryGetDenseResourceValues<int16_t >(attr);
244+ data && elementTy.isInteger (16 ))
245+ return transposeType (*data, inputType, outputType, permValues);
246+
247+ if (auto data = tryGetDenseResourceValues<int32_t >(attr);
248+ data && elementTy.isInteger (32 ))
249+ return transposeType (*data, inputType, outputType, permValues);
250+
251+ if (auto data = tryGetDenseResourceValues<int64_t >(attr);
252+ data && elementTy.isInteger (64 ))
253+ return transposeType (*data, inputType, outputType, permValues);
254+
255+ if (auto data = tryGetDenseResourceValues<float >(attr);
256+ data && elementTy.isF32 ())
257+ return transposeType (*data, inputType, outputType, permValues);
258+ }
259+
207260 return nullptr ;
208261}
209262
0 commit comments