Skip to content

Add F.scaled_mm#2720

Merged
KaelanDt merged 17 commits intomainfrom
crpa/scaled_mm
Dec 8, 2025
Merged

Add F.scaled_mm#2720
KaelanDt merged 17 commits intomainfrom
crpa/scaled_mm

Conversation

@crcrpar
Copy link
Collaborator

@crcrpar crcrpar commented Nov 6, 2025

What does this PR do?

As per title, this PR adds F.scaled_mm to thunder.torch and cover it with torchex impl.

Ref: https://docs.pytorch.org/docs/main/generated/torch.nn.functional.scaled_mm.html

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds support for torch.nn.functional.scaled_mm operation in Thunder. This operation performs scaled matrix multiplication, which is commonly used for FP8 quantized operations.

  • Adds a new scaled_mm function to thunder/torch/__init__.py with input validation and shape inference logic
  • Registers the implementation in thunder/executors/torchex.py to delegate to PyTorch
  • Adds comprehensive test coverage with tensor-wise, row-wise, and block-wise scaling tests

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
thunder/torch/init.py Implements the scaled_mm symbol with parameter validation and output shape/dtype inference
thunder/executors/torchex.py Registers the torch executor implementation for scaled_mm
thunder/tests/test_ops.py Adds test helper functions and comprehensive tests for scaled_mm with different scaling strategies

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@Lightning-AI Lightning-AI deleted a comment from kshitij12345 Nov 7, 2025
Copy link
Collaborator

@kshitij12345 kshitij12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, just a few comments regarding the tests. Thanks!

Copy link
Collaborator

@kshitij12345 kshitij12345 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @crcrpar

@crcrpar crcrpar requested a review from mattteochen November 27, 2025 06:25
Copy link
Collaborator

@mattteochen mattteochen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm @crcrpar

@crcrpar
Copy link
Collaborator Author

crcrpar commented Nov 28, 2025

@KaelanDt could you review this?

…wise scaling

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Copy link
Collaborator

@KaelanDt KaelanDt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you @crcrpar

@KaelanDt KaelanDt merged commit 006a243 into main Dec 8, 2025
50 checks passed
@KaelanDt KaelanDt deleted the crpa/scaled_mm branch December 8, 2025 11:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants