
GroupedMmaOp::evaluate can't be captured by CUDA graph #5434

@wujingyue


Repro

Lightning-AI/lightning-thunder#2697 (comment)

Root cause

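gdb points to the call to `.cpu()` in `GroupedMmaOp`'s fallback implementation. `.cpu()` copies `offsets` from the GPU to the host and blocks until the copy finishes. A synchronizing device-to-host copy like this is not allowed while a stream is being captured into a CUDA graph.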

(gdb) f 19
#19 0x00007ffe06eef393 in nvfuser::GroupedMmaOp::evaluate(nvfuser::ExpressionEvaluator const&, std::vector<dynamic_type::DynamicType<dynamic_type::Containers<std::vector>, nvfuser::StructHandle, nvfuser::Pointer, nvfuser::Opaque, at::Tensor, std::complex<double>, double, long, bool>, std::allocator<dynamic_type::DynamicType<dynamic_type::Containers<std::vector>, nvfuser::StructHandle, nvfuser::Pointer, nvfuser::Opaque, at::Tensor, std::complex<double>, double, long, bool> > > const&) const::$_0::operator()() const (this=0x7fffffff5fa0) at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:5835
5835        at::Tensor offsets_cpu = offsets.cpu();
(gdb) l
5830        NVF_ERROR(!alpha.defined(), "alpha is not supported yet");
5831        NVF_ERROR(!beta.defined(), "beta is not supported yet");
5832        NVF_ERROR(!bias.defined(), "bias is not supported yet");
5833
5834        // Compute numbers of tokens per group from offsets.
5835        at::Tensor offsets_cpu = offsets.cpu();
5836        NVF_ERROR_EQ(offsets_cpu.dtype(), at::kInt);
5837        const int* data_ptr = offsets_cpu.data_ptr<int>();
5838        const int64_t num_groups = offsets_cpu.numel();
5839        std::vector<int64_t> group_sizes(data_ptr, data_ptr + num_groups);
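The listing stops right after `offsets` is copied into `group_sizes`, before the per-group sizes are actually computed. Below is a minimal sketch of the conversion described by the comment on line 5834. It assumes that `offsets` holds the cumulative (exclusive) end row of each group, as in the `offs` argument of PyTorch's `torch._grouped_mm`, so each group size is the difference of adjacent offsets. `group_sizes_from_offsets` is a hypothetical helper and not nvFuser's code. Because this arithmetic runs on the host, the offsets first have to leave the GPU. That is exactly what the `.cpu()` call above does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper, not nvFuser's implementation. Assumes offsets[i] is the
// exclusive end row of group i and that group 0 starts at row 0.
std::vector<int64_t> group_sizes_from_offsets(const std::vector<int32_t>& offsets) {
  // Cumulative offsets must be non-decreasing.
  assert(std::is_sorted(offsets.begin(), offsets.end()));
  // Widen to int64_t first, like the int64_t group_sizes in the listing.
  std::vector<int64_t> group_sizes(offsets.begin(), offsets.end());
  // In place: group_sizes[0] stays offsets[0], and
  // group_sizes[i] becomes offsets[i] - offsets[i - 1].
  std::adjacent_difference(group_sizes.begin(), group_sizes.end(),
                           group_sizes.begin());
  return group_sizes;
}
```

For example, offsets `{3, 5, 9}` yield group sizes `{3, 2, 4}`.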

It was a bit hard to find this line. My libtorch doesn't come with debug symbols, and a plain `catch throw` fires far too often on `N8pybind1114stop_iterationE` (`pybind11::stop_iteration`).

So I set a `catch throw` catchpoint and attached the following commands to it:

catch throw
commands
silent
set $tinfo   = (const std::type_info*)( $rsi )
set $mangled = $tinfo->name()
set $status  = 0
set $pretty  = (char*)__cxxabiv1::__cxa_demangle($mangled, 0, 0, &$status)
if ($status == 0 && $pretty && strstr($pretty, "c10::AcceleratorError"))
  printf ">>> stopping on %s\n", $pretty
  bt full
else
  continue
end
end
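The condition above relies on `__cxa_throw` receiving the thrown object's `std::type_info*` as its second argument, which lands in `$rsi` under the x86-64 System V calling convention. The C++ sketch below shows the same test done in-process: demangle the type name and look for `c10::AcceleratorError` in it. It assumes the Itanium C++ ABI (GCC or Clang on Linux), where `<cxxabi.h>` provides `abi::__cxa_demangle`. The `c10::AcceleratorError` here is a hypothetical stand-in for PyTorch's class, and `matches_by_demangling` is a hypothetical name.

```cpp
#include <cstdlib>
#include <cxxabi.h>
#include <stdexcept>
#include <string>
#include <typeinfo>

// Hypothetical stand-in for PyTorch's c10::AcceleratorError; only the
// fully qualified name matters for this check.
namespace c10 {
struct AcceleratorError : std::runtime_error {
  using std::runtime_error::runtime_error;
};
}  // namespace c10

// Mirrors the gdb condition: demangle the thrown type's name and check
// whether it contains "c10::AcceleratorError".
bool matches_by_demangling(const std::type_info& tinfo) {
  int status = 0;
  char* pretty = abi::__cxa_demangle(tinfo.name(), nullptr, nullptr, &status);
  const bool match = status == 0 && pretty != nullptr &&
                     std::string(pretty).find("c10::AcceleratorError") !=
                         std::string::npos;
  std::free(pretty);  // __cxa_demangle allocates the result with malloc
  return match;
}
```
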

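With that, I got a catchpoint that stops only on `c10::AcceleratorError`. `info b` shows the version of the commands I ended up with, which compares the raw mangled type name instead of demangling it: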

(gdb) info b
Num     Type           Disp Enb Address            What
4       catchpoint     keep y                      exception throw
        catchpoint already hit 1393 times
        silent
        set $namep = (char**)( $rsi )
        set $namep = $namep + 1
        set $mangled = *$namep
        if ($mangled && ((int)strcmp($mangled, "N3c1016AcceleratorErrorE") == 0))
          printf ">>> stopping on %s\n", $mangled
          bt full
        else
          continue
        end
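Under the Itanium C++ ABI, a `std::type_info` object holds a vtable pointer followed by a pointer to the mangled type name. That is why the commands above read the second pointer-sized word at `$rsi`. `N3c1016AcceleratorErrorE` is the length-prefixed nested name (`3c10`, `16AcceleratorError`) wrapped in `N...E`. The sketch below performs the same exact comparison in C++. One plausible advantage of this form, given that the catchpoint was hit more than 1,000 times, is that it avoids calling `__cxa_demangle` in the inferior on every throw. As in the previous sketch, `c10::AcceleratorError` is a hypothetical stand-in and `matches_by_mangled_name` is a hypothetical name. This sketch also assumes GCC or Clang on Linux.

```cpp
#include <cstring>
#include <stdexcept>
#include <typeinfo>

// Hypothetical stand-in for PyTorch's c10::AcceleratorError.
namespace c10 {
struct AcceleratorError : std::runtime_error {
  using std::runtime_error::runtime_error;
};
}  // namespace c10

// Itanium-ABI mangled name of c10::AcceleratorError: N <3c10> <16AcceleratorError> E.
constexpr const char* kMangledAcceleratorError = "N3c1016AcceleratorErrorE";

// Same test as the gdb strcmp: an exact match on the mangled name, so
// subclasses of c10::AcceleratorError do not match.
bool matches_by_mangled_name(const std::type_info& tinfo) {
  return std::strcmp(tinfo.name(), kMangledAcceleratorError) == 0;
}
```

`typeid` on a caught `const std::exception&` reports the dynamic type, so the check also works on an exception caught through its base class.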
