Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4286,8 +4286,9 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con
actual = torch.zeros(expected.shape, dtype=torch.int32, device=device)

k = kernel[(1, )](input, actual, shape[0], shape[1])
assert k.asm['ttgir'].count(
'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization"
if not is_xpu():
assert k.asm['ttgir'].count(
'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization"

np.testing.assert_equal(to_numpy(expected), to_numpy(actual))

Expand Down
68 changes: 58 additions & 10 deletions test/TritonIntelGPU/rewrite-tensor-pointer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
%23:3 = scf.for %arg10 = %c0_i32 to %arg6 step %c32_i32 iter_args(%arg11 = %cst, %arg12 = %18, %arg13 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
Expand All @@ -59,8 +59,8 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
%25 = arith.extsi %arg9 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #[[DPAS]]>>
%26 = tt.make_tensor_ptr %arg3, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #dpas>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
%27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #dpas>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
%27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #dpas>>
%28 = arith.addf %23#0, %27 : tensor<256x256xf32, #dpas>
%29 = arith.truncf %28 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>

Expand Down Expand Up @@ -125,10 +125,10 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #dot0>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #dot1>>
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
Expand Down Expand Up @@ -335,3 +335,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
tt.return
}
}

// -----

// COM: Case 5:
// COM: Check that a make tensor ptr with no loads is properly removed
// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// CHECK: @matmul_kernel_with_block_pointers
%c4_i32 = arith.constant 4 : i32
%c256_i32 = arith.constant 256 : i32
%c1024_i64 = arith.constant 1024 : i64
%c5120_i64 = arith.constant 5120 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c4096_i64 = arith.constant 4096 : i64
%c32_i32 = arith.constant 32 : i32
%c64_i32 = arith.constant 64 : i32
%c5120_i32 = arith.constant 5120 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
%0 = tt.get_program_id x : i32
%1 = arith.divsi %0, %c64_i32 : i32
%2 = arith.muli %1, %c4_i32 : i32
%3 = arith.subi %c4_i32, %2 : i32
%4 = arith.minsi %3, %c4_i32 : i32
%5 = arith.remsi %0, %4 : i32
%6 = arith.addi %2, %5 : i32
%7 = arith.remsi %0, %c64_i32 : i32
%8 = arith.divsi %7, %4 : i32
%9 = arith.muli %6, %c256_i32 : i32
// CHECK-NOT: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
%10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
%11 = arith.muli %8, %c256_i32 : i32
// CHECK-NOT: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
%13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>) : i32 {
// CHECK-NOT: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK-NOT: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%19 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
%20 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
scf.yield %arg4, %19, %20 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
}
%14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
%15 = arith.truncf %13#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
tt.return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ struct TritonIntelGPUMaterializeBlockPointerPass
return;
}

if (fastChangeDim == rank - 2 &&
tensorType.getElementTypeBitWidth() == 8) {
// TODO: column major layout w/ fp8 has performance regression
return;
}

if (fastChangeDim >= (rank - 2)) {
// HW 2D block read instruction only supports contiguous access.
Value fastChangeStride = strides[fastChangeDim];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ namespace {
/// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
/// - its pitch is not divisible by Qword bitwidth
/// - it is not contiguous in memory
bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByStoreOp,
const bool isUsedByBlockLoadOp) {
LDBG("Considering removal of: " << op);
if (!op->getParentOfType<ModuleOp>()->hasAttr(
ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
Expand All @@ -45,61 +46,19 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
LDBG("Op ptr type: " << ptrType);
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
LDBG("Op tensor type: " << tensorType);

if (!ttgi::hasDotDpasEncoding(tensorType) &&
!(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used "
"by load or store op with DPAS layout");
return true;
}

TypedValue<triton::PointerType> base = op.getBase();
Operation::operand_range shape = op.getShape();
unsigned rank = shape.size();
assert(rank > 1 && "Expecting tensor with rank > 1");
Operation::operand_range strides = op.getStrides();
Operation::operand_range offsets = op.getOffsets();
ArrayRef<int32_t> order = op.getOrder();
ArrayRef<int64_t> tensorShape = tensorType.getShape();

int fastChangeDim = -1;
for (size_t i = 0; i < strides.size(); ++i) {
if (ttgi::isConstant(strides[i], 1)) {
fastChangeDim = i;
break;
}
}

LDBG("fastChangeDim: " << fastChangeDim);
if (fastChangeDim < 0) {
LDBG("Marked for removal: fast changing dimension not found");
return true;
}

LDBG("Tensor type element type bit width: "
<< tensorType.getElementTypeBitWidth());
if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
// TODO: column major layout w/ fp8 has performance regression
LDBG("Marked for removal: column major layout with fp8 element type");
return true;
}

// HW 2D block read instruction has restriction on pitch divisibility
if (fastChangeDim >= (rank - 2)) {
auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
LDBG("Pitch: " << pitch);
// Across Intel platforms, the strictest pitch restriction is to be a
// multiple of OWord(128 bits).
if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
LDBG("Marked for removal: cannot use block read/write instructions");
return true;
}

LDBG("Used by store op? " << isUsedByStoreOp);
LDBG("Used by block load op? " << isUsedByBlockLoadOp);

LDBG("hasDpasEncoding: " << ttgi::hasDpasEncoding(tensorType));
if (isUsedByBlockLoadOp ||
(isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType))) {
LDBG("Tensor has DPAS layout or is used by load/store op with DPAS layout, "
"skipping removal");
return false;
}

LDBG("Marked for removal: fall-trough");

LDBG("Marked for removal: make tensor ptr op is not used by block load op or "
"by store op with DPAS layout");
return true;
}

Expand Down Expand Up @@ -715,28 +674,73 @@ class TritonIntelGPURewriteTensorPointerPass
void runOnOperation() override {
ModuleOp mod = getOperation();

auto usedByLoadOrStoreOp = [](Value val) {
return llvm::any_of(val.getUsers(), [](Operation *user) {
return isa<tt::LoadOp, tt::StoreOp>(user);
});
};
DenseSet<Operation *> tensorPointersToRemove;
mod.walk([&](tt::MakeTensorPtrOp makeTensorPtrOp) {
tensorPointersToRemove.insert(makeTensorPtrOp);
DenseSet<Operation *> workingSet;

auto markTensorPointerForRemoval =
[this](Value val, bool isUsedByLoadOrStoreOp = false) {
if (tt::isTensorPointerType(val.getType())) {
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp))
valueToRemove.insert(val);
LDBG("Considering: " << makeTensorPtrOp);
Value result = makeTensorPtrOp.getResult();
for (auto user : result.getUsers()) {
workingSet.insert(user);
}
while (!workingSet.empty()) {
auto crtOpItr = workingSet.begin();
auto crtOp = *crtOpItr;
LDBG("Processing op: " << *crtOp);
if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
if (!shouldRemove(
makeTensorPtrOp,
/*isUsedByStoreOp=*/isa<tt::StoreOp>(crtOp),
/*isBlockLoad=*/
isa<tt::LoadOp>(crtOp) &&
crtOp->hasAttr(
ttgi::TritonIntelGPUDialect::getBlockIOAttrName()))) {
tensorPointersToRemove.erase(makeTensorPtrOp);
return WalkResult::advance();
}
};
} else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
for (auto [arg, blockArg] :
llvm::zip(forOp.getInitArgs(),
forOp.getBody()->getArguments().drop_front(
forOp.getNumInductionVars()))) {
if (arg == makeTensorPtrOp) {
// add users of block arg
for (auto user : blockArg.getUsers()) {
workingSet.insert(user);
}
}
}
} else if (crtOp->getNumResults() > 0) {
// TODO: should we handle more than one result?
auto crtOpResult = crtOp->getResult(0);
LDBG("Not a load store and not a loop, adding users to working "
"set.");
for (auto user : crtOpResult.getUsers()) {
workingSet.insert(user);
}
}
workingSet.erase(crtOpItr);
}
return WalkResult::advance();
});

auto markTensorPointerForRemoval = [this,
&tensorPointersToRemove](Value val) {
if (tt::isTensorPointerType(val.getType())) {
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
if (tensorPointersToRemove.count(makeTensorPtrOp)) {
valueToRemove.insert(val);
}
}
};

mod.walk([&](Operation *op) {
if (isa<tt::MakeTensorPtrOp>(op)) {
Value result = op->getResult(0);
markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
markTensorPointerForRemoval(result);
} else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
markTensorPointerForRemoval(op->getOperand(0),
isa<tt::LoadOp, tt::StoreOp>(op));
markTensorPointerForRemoval(op->getOperand(0));
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
for (auto arg : forOp.getInitArgs())
markTensorPointerForRemoval(arg);
Expand All @@ -752,7 +756,7 @@ class TritonIntelGPURewriteTensorPointerPass
else {
DBGS() << "Values to remove: ";
for (auto val : valueToRemove)
DBGS() << val;
DBGS() << val << "\n";
}
});

Expand Down