Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
llvm::sort(dimensions.m.begin(), dimensions.m.end());
llvm::sort(dimensions.n.begin(), dimensions.n.end());
llvm::sort(dimensions.k.begin(), dimensions.k.end());
llvm::sort(dimensions.batch);
llvm::sort(dimensions.m);
llvm::sort(dimensions.n);
llvm::sort(dimensions.k);
return dimensions;
}

Expand Down Expand Up @@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
SmallVector<unsigned, 2>(depth.begin(), depth.end()),
/*strides=*/SmallVector<int64_t, 2>{},
/*dilations=*/SmallVector<int64_t, 2>{}};
llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
llvm::sort(dimensions.batch);
llvm::sort(dimensions.outputImage);
llvm::sort(dimensions.outputChannel);
llvm::sort(dimensions.filterLoop);
llvm::sort(dimensions.inputChannel);
llvm::sort(dimensions.depth);

// Use the op carried strides/dilations attribute if present.
auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
Expand Down
45 changes: 22 additions & 23 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3877,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
llvm::SmallVector<size_t> indices(indexAttr.size());
std::iota(indices.begin(), indices.end(), 0);

llvm::sort(indices.begin(), indices.end(),
[&](const size_t a, const size_t b) {
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();

if (aIndex == bIndex)
continue;

if (aIndex < bIndex)
return first;

if (aIndex > bIndex)
return !first;
}

// Iterated the up until the end of the smallest member and
// they were found to be equal up to that point, so select
// the member with the lowest index count, so the "parent"
return memberIndicesA.size() < memberIndicesB.size();
});
llvm::sort(indices, [&](const size_t a, const size_t b) {
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();

if (aIndex == bIndex)
continue;

if (aIndex < bIndex)
return first;

if (aIndex > bIndex)
return !first;
}

// Iterated the up until the end of the smallest member and
// they were found to be equal up to that point, so select
// the member with the lowest index count, so the "parent"
return memberIndicesA.size() < memberIndicesB.size();
});

return llvm::cast<omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
Expand Down