
Add utility functions for device management and input validation#225

Merged
LoserCheems merged 7 commits into main from optime-triton-kernels
Feb 8, 2026

Conversation

@LoserCheems
Collaborator

Summary

  • Introduces utility functions for managing devices and validating tensor inputs.

Root Cause

  • Not a bug fix: these are enhancements to improve clarity and consistency in device management and input validation.

Changes

  • Added functions for device detection, architecture retrieval, and input validation.
  • Refactored the autotuning configuration to use the new utility functions.
  • Initialized output tensors to zero so the forward base kernel handles them correctly.
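The diff itself is not shown on this page, so as a rough illustration of the input-validation utilities described above, here is a minimal sketch of shape-consistency checks for attention inputs. The function name, signature, and error messages are assumptions for illustration, not the repository's actual API:

```python
def validate_attention_inputs(q_shape, k_shape, v_shape):
    """Check that query/key/value shapes are mutually consistent.

    Shapes are (batch, seqlen, num_heads, head_dim) tuples. This is a
    hypothetical sketch of the kind of validation the PR describes, not
    the code actually merged.
    """
    if len(q_shape) != 4 or len(k_shape) != 4 or len(v_shape) != 4:
        raise ValueError("expected 4-D (batch, seqlen, heads, head_dim) shapes")
    if k_shape != v_shape:
        raise ValueError(f"key/value shapes differ: {k_shape} vs {v_shape}")
    batch_q, _, heads_q, dim_q = q_shape
    batch_k, _, heads_k, dim_k = k_shape
    if batch_q != batch_k:
        raise ValueError(f"batch size mismatch: {batch_q} vs {batch_k}")
    if dim_q != dim_k:
        raise ValueError(f"head_dim mismatch: {dim_q} vs {dim_k}")
    # Grouped-query attention: query heads must tile evenly over KV heads.
    if heads_q % heads_k != 0:
        raise ValueError("num query heads must be a multiple of num kv heads")
```

Failing fast on shape mismatches like this keeps the error at the Python boundary instead of surfacing as an opaque illegal memory access inside a Triton kernel.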

Reproduction

  • No specific bug to reproduce; the changes are functional enhancements rather than a fix.

Tests

  • Validated functionality through existing tests and ensured no regressions.

Compatibility

  • No backward compatibility issues; new functions added without altering existing interfaces.

Checklist

Copilot AI review requested due to automatic review settings February 1, 2026 12:09
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces utility functions for device management, input validation, and autotuning configuration to support a Triton-based multi-platform attention implementation (related to issue #222). The changes split the monolithic _flash_attn_forward function into separate base and varlen variants while extracting common utilities.

Changes:

  • Added a new utils.py module with device detection, architecture retrieval, autotuning configuration generation, and input validation functions
  • Refactored flash_fwd.py to use the new utility functions and split the forward pass into _flash_attn_base_forward and _flash_attn_varlen_base_forward functions
  • Enabled CUDA graph support in the autotune decorator
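As context for the autotuning-configuration utility mentioned above, a common pattern is to enumerate the cross product of candidate block sizes and warp counts. The sketch below is an assumption about what such a helper might look like (the function name and option defaults are invented; the real module presumably wraps these dicts in `triton.Config` objects):

```python
from itertools import product


def generate_autotune_configs(block_m_options=(64, 128),
                              block_n_options=(32, 64),
                              num_warps_options=(4, 8)):
    """Enumerate candidate kernel configurations for autotuning.

    Hypothetical sketch: returns one plain dict per combination of
    block sizes and warp count, over which an autotuner would search.
    """
    configs = []
    for block_m, block_n, num_warps in product(block_m_options,
                                               block_n_options,
                                               num_warps_options):
        configs.append({"BLOCK_M": block_m,
                        "BLOCK_N": block_n,
                        "num_warps": num_warps})
    return configs
```

Centralizing this enumeration in one utility is what lets the refactored `flash_fwd.py` share a single autotuning policy across the base and varlen forward variants.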

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 15 comments.

Files reviewed:

  • flash_sparse_attn/ops/triton/utils.py — New utility module providing device detection, architecture identification, autotuning configuration, grid generation, and input validation functions.
  • flash_sparse_attn/ops/triton/flash_fwd.py — Refactored to use the new utilities, split the forward pass into separate base and varlen functions, and enabled CUDA graph support.
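The grid-generation utility mentioned for utils.py is not shown here; in Triton, launch grids are typically expressed as a callable over the autotuner's meta-parameters, so the grid can adapt to whichever block size wins. A minimal sketch under that assumption (names invented for illustration):

```python
def ceil_div(a, b):
    """Integer ceiling division, used to count blocks along a dimension."""
    return (a + b - 1) // b


def make_grid(seqlen_q, batch, num_heads):
    """Build a grid callable for a flash-attention forward launch.

    Hypothetical sketch: one program per BLOCK_M-sized tile of the query
    sequence, times one program per (batch, head) pair. Triton invokes
    the returned callable with the chosen autotune meta-parameters.
    """
    def grid(meta):
        return (ceil_div(seqlen_q, meta["BLOCK_M"]), batch * num_heads)
    return grid
```

Because the grid is a function of `meta`, the same launch site works for every configuration the autotuner tries, which is what makes sharing it between the base and varlen forward paths practical.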


@LoserCheems LoserCheems merged commit 9db5fec into main Feb 8, 2026
