@@ -330,25 +330,65 @@ LogicalResult LLVM::detail::oneToOneRewrite(
330330 return failure ();
331331 }
332332
333- // If the targetAttrs contains DenseElementsAttr,
334- // and the element type of the DenseElementsAttr and result type is
335- // inconsistent after the conversion of result types, we need to convert the
336- // element type of the DenseElementsAttr to the target type by creating a new
337- // DenseElementsAttr with the converted element type, and use the new
338- // DenseElementsAttr to replace the old one in the targetAttrs
333+ // Convert attribute element types to match the converted result types.
334+ // This ensures that attributes like
335+ // dense<0.0> : vector<4xf8E4M3FN> become
336+ // dense<0> : vector<4xi8>
337+ // when the result type is converted to i8.
339338 SmallVector<NamedAttribute> convertedAttrs;
340339 for (auto attr : targetAttrs) {
341- if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue ())) {
342- VectorType vectorType = dyn_cast<VectorType>(denseAttr.getType ());
343- if (vectorType) {
344- auto convertedElementType =
345- typeConverter.convertType (vectorType.getElementType ());
346- VectorType convertedVectorType =
347- VectorType::get (vectorType.getShape (), convertedElementType,
348- vectorType.getScalableDims ());
340+ if (auto floatAttr = dyn_cast<FloatAttr>(attr.getValue ())) {
341+ auto convertedElementType =
342+ typeConverter.convertType (floatAttr.getType ());
343+ if (convertedElementType != floatAttr.getType ()) {
344+ // Currently, only 1-byte or sub-byte float types will be converted and
345+ // converted to integer types.
346+ convertedAttrs.emplace_back (
347+ attr.getName (),
348+ IntegerAttr::get (convertedElementType,
349+ floatAttr.getValue ().bitcastToAPInt ()));
350+ } else {
351+ convertedAttrs.emplace_back (attr);
352+ }
353+ } else if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue ())) {
354+ auto convertedElementType = typeConverter.convertType (intAttr.getType ());
355+ if (convertedElementType != intAttr.getType ()) {
349356 convertedAttrs.emplace_back (
350- attr.getName (), DenseElementsAttr::getFromRawBuffer (
351- convertedVectorType, denseAttr.getRawData ()));
357+ attr.getName (),
358+ IntegerAttr::get (convertedElementType, intAttr.getValue ()));
359+ } else {
360+ convertedAttrs.emplace_back (attr);
361+ }
362+ } else if (auto denseAttr = dyn_cast<DenseElementsAttr>(attr.getValue ())) {
363+ if (auto shapedType = dyn_cast<ShapedType>(denseAttr.getType ())) {
364+ auto convertedElementType =
365+ typeConverter.convertType (shapedType.getElementType ());
366+ if (convertedElementType != shapedType.getElementType ()) {
367+ ShapedType convertedShapedType =
368+ shapedType.cloneWith (std::nullopt , convertedElementType);
369+ convertedAttrs.emplace_back (
370+ attr.getName (), DenseElementsAttr::getFromRawBuffer (
371+ convertedShapedType, denseAttr.getRawData ()));
372+ } else {
373+ convertedAttrs.emplace_back (attr);
374+ }
375+ }
376+ } else if (auto sparseAttr =
377+ dyn_cast<SparseElementsAttr>(attr.getValue ())) {
378+ if (auto shapedType = dyn_cast<ShapedType>(sparseAttr.getType ())) {
379+ auto convertedElementType =
380+ typeConverter.convertType (shapedType.getElementType ());
381+ if (convertedElementType != shapedType.getElementType ()) {
382+ ShapedType convertedShapedType =
383+ shapedType.cloneWith (std::nullopt , convertedElementType);
384+ convertedAttrs.emplace_back (
385+ attr.getName (),
386+ SparseElementsAttr::get (
387+ convertedShapedType, sparseAttr.getIndices (),
388+ sparseAttr.getValues ().bitcast (convertedElementType)));
389+ } else {
390+ convertedAttrs.emplace_back (attr);
391+ }
352392 }
353393 } else {
354394 convertedAttrs.push_back (attr);
0 commit comments