diff --git a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp index b1f3f41a28e8b..0b7cf2f970172 100644 --- a/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp +++ b/llvm/lib/Target/DirectX/DXILFlattenArrays.cpp @@ -86,6 +86,13 @@ class DXILFlattenArraysVisitor Value *genInstructionFlattenIndices(ArrayRef Indices, ArrayRef Dims, IRBuilder<> &Builder); + + // Helper function to collect indices and dimensions from a GEP instruction + void collectIndicesAndDimsFromGEP(GetElementPtrInst &GEP, + SmallVectorImpl &Indices, + SmallVectorImpl &Dims, + bool &AllIndicesAreConstInt); + void recursivelyCollectGEPs(GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, Value *PtrOperand, @@ -218,6 +225,26 @@ bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) { return true; } +void DXILFlattenArraysVisitor::collectIndicesAndDimsFromGEP( + GetElementPtrInst &GEP, SmallVectorImpl &Indices, + SmallVectorImpl &Dims, bool &AllIndicesAreConstInt) { + + Type *CurrentType = GEP.getSourceElementType(); + + // Note index 0 is the ptr index. + for (Value *Index : llvm::drop_begin(GEP.indices(), 1)) { + Indices.push_back(Index); + AllIndicesAreConstInt &= isa(Index); + + if (auto *ArrayTy = dyn_cast(CurrentType)) { + Dims.push_back(ArrayTy->getNumElements()); + CurrentType = ArrayTy->getElementType(); + } else { + assert(false && "Expected array type in GEP chain"); + } + } +} + void DXILFlattenArraysVisitor::recursivelyCollectGEPs( GetElementPtrInst &CurrGEP, ArrayType *FlattenedArrayType, Value *PtrOperand, unsigned &GEPChainUseCount, SmallVector Indices, @@ -226,12 +253,8 @@ void DXILFlattenArraysVisitor::recursivelyCollectGEPs( if (GEPChainMap.count(&CurrGEP) > 0) return; - Value *LastIndex = CurrGEP.getOperand(CurrGEP.getNumOperands() - 1); - AllIndicesAreConstInt &= isa(LastIndex); - Indices.push_back(LastIndex); - assert(isa(CurrGEP.getSourceElementType())); - Dims.push_back( - cast(CurrGEP.getSourceElementType())->getNumElements()); + // Collect indices and dimensions from the current GEP + collectIndicesAndDimsFromGEP(CurrGEP, Indices, Dims, AllIndicesAreConstInt); bool IsMultiDimArr = isMultiDimensionalArray(CurrGEP.getSourceElementType()); if (!IsMultiDimArr) { assert(GEPChainUseCount < FlattenedArrayType->getNumElements()); @@ -316,9 +339,12 @@ bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Handle zero uses here because there won't be an update via // a child in the chain later. if (GEPChainUseCount == 0) { - SmallVector Indices({GEP.getOperand(GEP.getNumOperands() - 1)}); - SmallVector Dims({ArrType->getNumElements()}); - bool AllIndicesAreConstInt = isa(Indices[0]); + SmallVector Indices; + SmallVector Dims; + bool AllIndicesAreConstInt = true; + + // Collect indices and dimensions from the GEP + collectIndicesAndDimsFromGEP(GEP, Indices, Dims, AllIndicesAreConstInt); GEPData GEPInfo{std::move(FlattenedArrayType), PtrOperand, std::move(Indices), std::move(Dims), AllIndicesAreConstInt}; return visitGetElementPtrInstInGEPChainBase(GEPInfo, GEP); diff --git a/llvm/test/CodeGen/DirectX/flatten-array.ll b/llvm/test/CodeGen/DirectX/flatten-array.ll index 5c761014d471f..dc8c5f8421bfe 100644 --- a/llvm/test/CodeGen/DirectX/flatten-array.ll +++ b/llvm/test/CodeGen/DirectX/flatten-array.ll @@ -187,5 +187,75 @@ define void @global_gep_store() { ret void } +@g = local_unnamed_addr addrspace(3) global [2 x [2 x float]] zeroinitializer, align 4 +define void @two_index_gep() { + ; CHECK-LABEL: define void @two_index_gep( + ; CHECK: [[THREAD_ID:%.*]] = tail call i32 @llvm.dx.thread.id(i32 0) + ; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[THREAD_ID]], 2 + ; CHECK-NEXT: [[ADD:%.*]] = add i32 1, [[MUL]] + ; CHECK-NEXT: [[GEP_PTR:%.*]] = getelementptr inbounds nuw [4 x float], ptr addrspace(3) @g.1dim, i32 0, i32 [[ADD]] + ; CHECK-NEXT: load float, ptr addrspace(3) [[GEP_PTR]], align 4 + ; CHECK-NEXT: ret void + %1 = tail call i32 @llvm.dx.thread.id(i32 0) + %2 = getelementptr inbounds nuw [2 x [2 x float]], ptr addrspace(3) @g, i32 0, i32 %1, i32 1 + %3 = load float, ptr addrspace(3) %2, align 4 + ret void +} + +define void @two_index_gep_const() { + ; CHECK-LABEL: define void @two_index_gep_const( + ; CHECK-NEXT: [[GEP_PTR:%.*]] = getelementptr inbounds nuw [4 x float], ptr addrspace(3) @g.1dim, i32 0, i32 3 + ; CHECK-NEXT: load float, ptr addrspace(3) [[GEP_PTR]], align 4 + ; CHECK-NEXT: ret void + %1 = getelementptr inbounds nuw [2 x [2 x float]], ptr addrspace(3) @g, i32 0, i32 1, i32 1 + %3 = load float, ptr addrspace(3) %1, align 4 + ret void +} + +define void @gep_4d_index_test() { + ; CHECK-LABEL: gep_4d_index_test + ; CHECK: [[a:%.*]] = alloca [16 x i32], align 4 + ; CHECK-NEXT: getelementptr inbounds [16 x i32], ptr %.1dim, i32 0, i32 1 + ; CHECK-NEXT: getelementptr inbounds [16 x i32], ptr %.1dim, i32 0, i32 3 + ; CHECK-NEXT: getelementptr inbounds [16 x i32], ptr %.1dim, i32 0, i32 7 + ; CHECK-NEXT: getelementptr inbounds [16 x i32], ptr %.1dim, i32 0, i32 15 + ; CHECK-NEXT: ret void + %1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4 + %2 = getelementptr inbounds [2 x [2 x[2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 0, i32 0, i32 1 + %3 = getelementptr inbounds [2 x [2 x[2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 0, i32 1, i32 1 + %4 = getelementptr inbounds [2 x [2 x[2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 1, i32 1, i32 1 + %5 = getelementptr inbounds [2 x [2 x[2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 1, i32 1, i32 1, i32 1 + ret void +} + +define void @gep_4d_index_and_gep_chain_mixed() { + ; CHECK-LABEL: gep_4d_index_and_gep_chain_mixed + ; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [16 x i32], align 4 + ; CHECK-COUNT-16: getelementptr inbounds [16 x i32], ptr [[ALLOCA]], i32 0, i32 {{[0-9]|1[0-5]}} + ; CHECK-NEXT: ret void + %1 = alloca [2x[2 x[2 x [2 x i32]]]], align 4 + %a4d0_0 = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], [2 x [2 x[2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 0 + %a2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %a4d0_0, i32 0, i32 0, i32 0 + %a2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %a4d0_0, i32 0, i32 0, i32 1 + %a2d1_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %a4d0_0, i32 0, i32 1, i32 0 + %a2d1_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %a4d0_0, i32 0, i32 1, i32 1 + %b4d0_1 = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 0, i32 1 + %b2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %b4d0_1, i32 0, i32 0, i32 0 + %b2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %b4d0_1, i32 0, i32 0, i32 1 + %b2d1_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %b4d0_1, i32 0, i32 1, i32 0 + %b2d1_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %b4d0_1, i32 0, i32 1, i32 1 + %c4d1_0 = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 1, i32 0 + %c2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %c4d1_0, i32 0, i32 0, i32 0 + %c2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %c4d1_0, i32 0, i32 0, i32 1 + %c2d1_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %c4d1_0, i32 0, i32 1, i32 0 + %c2d1_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %c4d1_0, i32 0, i32 1, i32 1 + %g4d1_1 = getelementptr inbounds [2 x [2 x [2 x [2 x i32]]]], [2 x [2 x [2 x [2 x i32]]]]* %1, i32 0, i32 1, i32 1 + %g2d0_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g4d1_1, i32 0, i32 0, i32 0 + %g2d0_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g4d1_1, i32 0, i32 0, i32 1 + %g2d1_0 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g4d1_1, i32 0, i32 1, i32 0 + %g2d1_1 = getelementptr inbounds [2 x [2 x i32]], [2 x [2 x i32]]* %g4d1_1, i32 0, i32 1, i32 1 + ret void +} + ; Make sure we don't try to walk the body of a function declaration. declare void @opaque_function() diff --git a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll index c960aad3d2627..778113bd3160f 100644 --- a/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll +++ b/llvm/test/CodeGen/DirectX/llc-vector-load-scalarize.ll @@ -111,13 +111,13 @@ define <4 x i32> @multid_load_test() #0 { ; CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr addrspace(3) [[TMP5]], align 4 ; CHECK-NEXT: [[TMP7:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 3) to ptr addrspace(3) ; CHECK-NEXT: [[TMP8:%.*]] = load i32, ptr addrspace(3) [[TMP7]], align 4 -; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 1) to ptr addrspace(3) +; CHECK-NEXT: [[TMP9:%.*]] = bitcast ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4) to ptr addrspace(3) ; CHECK-NEXT: [[TMP10:%.*]] = load i32, ptr addrspace(3) [[TMP9]], align 4 -; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 1), i32 1) to ptr addrspace(3) +; CHECK-NEXT: [[TMP11:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 1) to ptr addrspace(3) ; CHECK-NEXT: [[TMP12:%.*]] = load i32, ptr addrspace(3) [[TMP11]], align 4 -; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 1), i32 2) to ptr addrspace(3) +; CHECK-NEXT: [[TMP13:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 2) to ptr addrspace(3) ; CHECK-NEXT: [[TMP14:%.*]] = load i32, ptr addrspace(3) [[TMP13]], align 4 -; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 1), i32 3) to ptr addrspace(3) +; CHECK-NEXT: [[TMP15:%.*]] = bitcast ptr addrspace(3) getelementptr (i32, ptr addrspace(3) getelementptr inbounds ([36 x i32], ptr addrspace(3) @groushared2dArrayofVectors.scalarized.1dim, i32 0, i32 4), i32 3) to ptr addrspace(3) ; CHECK-NEXT: [[TMP16:%.*]] = load i32, ptr addrspace(3) [[TMP15]], align 4 ; CHECK-NEXT: [[DOTI05:%.*]] = add i32 [[TMP2]], [[TMP10]] ; CHECK-NEXT: [[DOTI16:%.*]] = add i32 [[TMP4]], [[TMP12]]