
[BUG] <CUDA> crash on implicit gemm kernel #361

@lgyStoic

Description

Is there an existing issue for this?

  • I have searched the existing issues

Current Behavior

I wrote a minimal demo to reproduce this crash:

# Test case for torchsparse
import torch
import torchsparse
from torch.utils.checkpoint import checkpoint

def test_torchsparse_convolution_large_scale():
    """Test torchsparse convolution with large scale data"""
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Parameters
    # batch_size = 16777216
    # batch_size = 16777217
    # batch_size = 16777216
    batch_size = 17779768
    feat_dim = 64
    coord_dim = 4
    kernel_size = 3
    dilation = 1
    stride = 1
    
    try:
        # Create large scale sparse tensor
        
        # Generate random coordinates (batch_size x coord_dim)
        coords = torch.randint(0, 1000, (batch_size, coord_dim), device=device, dtype=torch.int32)
        
        # Generate random features (batch_size x feat_dim)
        feats = torch.randn(batch_size, feat_dim, device=device, dtype=torch.float16, requires_grad=True)
        
        # Create sparse tensor
        sparse_tensor = torchsparse.SparseTensor(coords=coords, feats=feats)
        in_cha = 64
        out_cha = 32
        conv_layer = torchsparse.nn.Conv3d(
            in_channels=in_cha,
            out_channels=out_cha,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            transposed=False,
            bias=True
        ).to(device)
        conv_func = torchsparse.nn.functional.conv3d

        compute_type = feats.dtype
        def cast(p: torch.Tensor) -> torch.Tensor:
            return p.to(compute_type)
        conv_kernel = checkpoint(cast, conv_layer.kernel, use_reentrant=True)
        conv_bias = (
            checkpoint(cast, conv_layer.bias, use_reentrant=True)
            if conv_layer.bias is not None
            else None
        )
        # Forward pass
        print("Running forward pass...")
        output = conv_func(sparse_tensor,
                    weight=conv_kernel,
                    kernel_size=conv_layer.kernel_size,
                    bias=conv_bias,
                    stride=conv_layer.stride,
                    padding=conv_layer.padding,
                    dilation=conv_layer.dilation,
                    transposed=conv_layer.transposed,
                    generative=conv_layer.generative,
                    config=conv_layer._config,
                    training=conv_layer.training,
                    )
        
        torch.cuda.synchronize()
        
        # Backward pass
        print("Running backward pass...")
        test_loss = output.feats.sum()
        test_loss.backward()
        
        torch.cuda.synchronize()
        return output
    except Exception as e:
        print(f"✗ Error in large scale test: {e}")
        import traceback
        traceback.print_exc()
        return None


if __name__ == "__main__":
    print("Running torchsparse tests...")
    
    test_torchsparse_convolution_large_scale()
    print("All tests completed!")

When the batch size increases, this code crashes during the backward pass.
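
For what it is worth, the commented-out sizes sit right at a 32-bit boundary: 16777216 is exactly 2**24, and several flattened index products cross INT32_MAX around that scale. This is only a hypothesis on my side (the product names below are my guesses, not identifiers from the TorchSparse source), but a quick check shows which ones are near or past the int32 limit:

INT32_MAX = 2**31 - 1
feat_dim = 64
kernel_volume = 3 ** 3  # 27 offsets for a 3x3x3 kernel

for n in (16_777_216, 16_777_217, 17_779_768):
    products = {
        "n * feat_dim": n * feat_dim,                # flattened feature offset
        "n * kernel_volume": n * kernel_volume,      # flattened neighbor-map offset
        "n * 128": n * 128,                          # a common fp16 GEMM tile multiple (speculative)
        "n * kernel_volume * feat_dim": n * kernel_volume * feat_dim,
    }
    for name, v in products.items():
        print(f"n={n}: {name} = {v} ({'OVERFLOWS' if v > INT32_MAX else 'fits in'} int32)")

Notably, n * 128 equals 2**31 exactly at n = 16777216, i.e. one past the largest int32 value.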

I debugged this repo and finally traced the crash to the CUDA kernel 'conv_forward_cuda_setting3_mode1_f16f16f32' (selected because my data type is fp16). The implicit GEMM logic is quite complicated; could you give a hint or a likely explanation? BTW, I tested this code on both an H100 and a 4090, and it crashes on both.
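
To localize the faulting access rather than guess, the repro can be run with serialized kernel launches under NVIDIA's compute-sanitizer (assuming the demo above is saved as repro.py and compute-sanitizer is on PATH):

import os
import subprocess

# CUDA_LAUNCH_BLOCKING=1 makes the error surface at the offending launch
# instead of at a later synchronize; memcheck reports the first bad access.
env = dict(os.environ, CUDA_LAUNCH_BLOCKING="1")
subprocess.run(
    ["compute-sanitizer", "--tool", "memcheck", "python", "repro.py"],
    env=env,
    check=False,
)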

Expected Behavior

No response

Environment

- GCC: 10.5.0
- NVCC: cuda_12.4.r12.4/compiler.34097967_0
- PyTorch: 2.7.1
- PyTorch CUDA: 12.6
- TorchSparse: 2.1.1

Anything else?

No response
