
[BUG] Cute-dsl 4.2 gives wrong output for if statement #2647

@tridao

Description


Which component has the problem?

CuTe DSL

Bug Report

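In the example below, CuTe DSL 4.2 gives wrong output, while CuTe DSL 4.1 gives the correct output. In 4.2, the assignment `self.state = state` at the end of `mutate()` is lost when the method also contains an `if` on `cute.arch.lane_idx()`.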

Steps/Code to reproduce bug

import cutlass
import cutlass.cute as cute
from cutlass import Int32


class MyClass:

    def __init__(self, state: Int32, tensor: cute.Tensor, *, loc=None, ip=None):
        self.state = state
        self.tensor = tensor
        self._loc = loc
        self._ip = ip

    @cute.jit
    def mutate(self, *, loc=None, ip=None):
        state = 1
        if cute.arch.lane_idx() == 0:  # Without this, output is correct
            self.tensor[0] = self.state
        self.state = state

    def __extract_mlir_values__(self):
        values, self._values_pos = [], []
        for obj in [self.state, self.tensor]:
            obj_values = cutlass.extract_mlir_values(obj)
            values += obj_values
            self._values_pos.append(len(obj_values))
        return values

    def __new_from_mlir_values__(self, values):
        obj_list = []
        for obj, n_items in zip([self.state, self.tensor], self._values_pos):
            obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items]))
            values = values[n_items:]
        return self.__class__(*obj_list, loc=self._loc, ip=self._ip)

@cute.kernel
def kernel():
    tensor = cute.make_fragment(1, Int32)
    tidx, _, _ = cute.arch.thread_idx()
    cls = MyClass(Int32(0), tensor)
    cute.printf("Before mutate, tidx = {}, state = {}", tidx, cls.state)
    cls.mutate()
    cute.printf("After mutate, tidx = {}, state = {}", tidx, cls.state)


@cute.jit
def kernel_launch():
    cutlass.cuda.initialize_cuda_context()
    kernel().launch(
        grid=(1, 1, 1),
        block=(2, 1, 1),
    )

if __name__ == "__main__":
    kernel_launch()
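The `__extract_mlir_values__` / `__new_from_mlir_values__` pair in `MyClass` flattens its fields into one list of values and later splits a list of the same length back into fields, using `_values_pos` to remember how many values each field contributed. The pure-Python sketch below models only that bookkeeping, so the round trip can be checked without a GPU. `extract`, `rebuild`, `flatten`, and `unflatten` are hypothetical stand-ins (not CuTe DSL APIs), and modelling the tensor as two leaf values is an illustrative assumption; the sketch says nothing about what the DSL does with the values between the two calls.

```python
from typing import Any, List, Tuple


# Hypothetical stand-ins for cutlass.extract_mlir_values / cutlass.new_from_mlir_values:
# a field is modelled as a tuple of leaf values.
def extract(obj: Tuple[Any, ...]) -> List[Any]:
    return list(obj)


def rebuild(template: Tuple[Any, ...], values: List[Any]) -> Tuple[Any, ...]:
    # The template only fixes how many leaves the field expects.
    assert len(values) == len(template)
    return tuple(values)


def flatten(fields: List[Tuple[Any, ...]]) -> Tuple[List[Any], List[int]]:
    """Mirror of __extract_mlir_values__: concatenate leaves, record per-field counts."""
    values: List[Any] = []
    counts: List[int] = []
    for obj in fields:
        obj_values = extract(obj)
        values += obj_values
        counts.append(len(obj_values))
    return values, counts


def unflatten(templates: List[Tuple[Any, ...]], values: List[Any],
              counts: List[int]) -> List[Tuple[Any, ...]]:
    """Mirror of __new_from_mlir_values__: slice the flat list back per field."""
    rebuilt = []
    for template, n in zip(templates, counts):
        rebuilt.append(rebuild(template, values[:n]))
        values = values[n:]
    return rebuilt


# Illustrative fields: state -> 1 leaf, tensor -> 2 leaves (assumed).
fields = [(0,), ("ptr", "layout")]
values, counts = flatten(fields)                                  # [0, 'ptr', 'layout'], [1, 2]
new_fields = unflatten(fields, [1, "ptr", "layout"], counts)      # [(1,), ('ptr', 'layout')]
```

If the counts and the slicing agree, a rebuilt object carries exactly the values it was handed, so a correct round trip on its own would not explain `state` reverting to 0.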

With nvidia-cutlass-dsl 4.1, output is correct:

Before mutate, tidx = 0, state = 0
Before mutate, tidx = 1, state = 0
After mutate, tidx = 0, state = 1
After mutate, tidx = 1, state = 1

With nvidia-cutlass-dsl 4.2, the output is wrong (`state` is still 0 after `mutate()` on both threads):

Before mutate, tidx = 0, state = 0
Before mutate, tidx = 1, state = 0
After mutate, tidx = 0, state = 0
After mutate, tidx = 1, state = 0
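For comparison, here is a minimal plain-Python sketch of what ordinary Python semantics give for `mutate()`, plus a small parser for the `After mutate` lines so a log can be checked mechanically. `ReferenceState` and `states_after_mutate` are hypothetical helpers written for this illustration; passing the lane index in explicitly assumes that, with `block=(2, 1, 1)`, `cute.arch.lane_idx()` equals `tidx`.

```python
import re
from typing import Dict, List


class ReferenceState:
    """Plain-Python stand-in for MyClass (no JIT), used only to show expected semantics."""

    def __init__(self, state: int, lane: int):
        self.state = state
        self.lane = lane            # hypothetical: lane index supplied by the caller
        self.tensor: List[int] = [0]

    def mutate(self) -> None:
        state = 1
        if self.lane == 0:
            self.tensor[0] = self.state
        self.state = state          # executes on every lane, whether or not the branch ran


AFTER_MUTATE = re.compile(r"After mutate, tidx = (\d+), state = (\d+)")


def states_after_mutate(log: str) -> Dict[int, int]:
    """Map tidx -> state using the 'After mutate' lines of a printf log."""
    return {int(t): int(s) for t, s in AFTER_MUTATE.findall(log)}


expected: Dict[int, int] = {}
for tidx in (0, 1):
    ref = ReferenceState(state=0, lane=tidx)
    ref.mutate()
    expected[tidx] = ref.state
# expected == {0: 1, 1: 1}, matching the 4.1 log; the 4.2 log parses to {0: 0, 1: 0}.
```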
