
Conversation

misrasaurabh1 (Contributor) commented Mar 14, 2025

PR Type

  • Enhancement
  • Tests

Description

  • Add torch dependency check and flag.
  • Implement torch.Tensor equality verification.
  • Include comprehensive torch tensor tests.

Changes walkthrough 📝

Relevant files

Enhancement

  comparator.py (codeflash/verification/comparator.py): Integrate torch.Tensor comparisons. (+12/-0)
  • Introduced torch import within try/except.
  • Defined HAS_TORCH flag for dependency handling.
  • Added torch.Tensor-specific comparison logic (a sketch of this guard follows the walkthrough).

Tests

  test_comparator.py (tests/test_comparator.py): Add torch tensor comparison tests. (+51/-0)
  • Added test_torch function to verify tensor behavior (a test sketch follows as well).
  • Tested equality, shape, dtype, and NaN handling.
  • Ensured tests are skipped when torch is unavailable.
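
Taking the walkthrough bullets together with the snippet quoted in the reviewer guide further down, the comparator change presumably looks roughly like the following. This is a minimal sketch: HAS_TORCH, the orig/new parameters, and the tensor branch come from the PR discussion, while the surrounding comparator function and its fallback behaviour are assumptions.

    # Sketch of the optional torch dependency guard and tensor branch described above.
    try:
        import torch
        HAS_TORCH = True
    except ImportError:  # torch is an optional dependency
        torch = None
        HAS_TORCH = False

    def comparator(orig, new) -> bool:
        if HAS_TORCH and isinstance(orig, torch.Tensor):
            if orig.dtype != new.dtype:
                return False
            if orig.shape != new.shape:
                return False
            # equal_nan=True so matching NaN positions count as equal
            return torch.allclose(orig, new, equal_nan=True)
        # the real comparator continues with other type-specific checks here
        return orig == new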
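
The test side can likewise only be sketched from the bullets: the skip mechanism and the exact assertions below are assumptions, with pytest.importorskip standing in for whatever guard tests/test_comparator.py actually uses.

    import pytest

    def test_torch():
        # Skip the whole test when torch is not installed.
        torch = pytest.importorskip("torch")
        a = torch.tensor([1.0, float("nan")])
        b = torch.tensor([1.0, float("nan")])
        assert comparator(a, b)                               # same values, NaNs aligned
        assert not comparator(a, torch.tensor([[1.0, 2.0]]))  # shape mismatch
        assert not comparator(a, torch.tensor([1, 2]))        # dtype mismatch (int64 vs float32)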

Need help?
  • Type /help how to ... in the comments thread for any questions about PR-Agent usage.
  • Check out the documentation for more information.

github-actions commented

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 3 🔵🔵🔵⚪⚪
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Type Safety

The torch tensor comparison logic only checks if the original object is a torch.Tensor and then compares dtypes, shapes, and values without confirming that the new object is also a torch.Tensor. It would be safer to ensure type consistency between both objects.

    if HAS_TORCH and isinstance(orig, torch.Tensor):
        if orig.dtype != new.dtype:
            return False
        if orig.shape != new.shape:
            return False
        return torch.allclose(orig, new, equal_nan=True)
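
To make the concern concrete (an illustrative case, not code from the PR): if new is, say, a plain Python list, the dtype access fails before any comparison happens.

    orig = torch.tensor([1.0, 2.0])
    new = [1.0, 2.0]   # not a tensor
    # orig.dtype != new.dtype raises AttributeError: 'list' object has no attribute 'dtype'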

github-actions commented

PR Code Suggestions ✨

Explore these optional code suggestions:

Category: Possible issue
Suggestion: Ensure tensor type consistency
Impact: Medium

Insert a type-check for new to confirm it is a torch.Tensor before accessing its attributes.

codeflash/verification/comparator.py [171-176]

     if HAS_TORCH and isinstance(orig, torch.Tensor):
    +    if not isinstance(new, torch.Tensor):
    +        return False
         if orig.dtype != new.dtype:
             return False
         if orig.shape != new.shape:
             return False
         return torch.allclose(orig, new, equal_nan=True)

Suggestion importance [1-10]: 7

Why: The suggestion correctly adds a type-check for new as a torch.Tensor, preventing potential attribute errors, and is an appropriate enhancement to the robustness of the code.

misrasaurabh1 requested a review from alvin-r March 14, 2025 01:54

aseembits93 (Contributor) commented Mar 14, 2025

This looks good. torch.allclose will throw exceptions for all the edge cases you are checking for (except requires_grad), so you could simply do try/except for conciseness. Not too sure if we want to check for requires_grad.

            return False
        if orig.device != new.device:
            return False
        return torch.allclose(orig, new, equal_nan=True)
A contributor commented:

torch.allclose with default params doesn't work as expected for tensors with low magnitude - might want to default to just using rtol?
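
A quick illustration of that point (the values below are chosen for this example, not taken from the PR): with the defaults rtol=1e-05 and atol=1e-08, the absolute-tolerance term dominates for tiny magnitudes, so tensors that differ element-wise by a factor of two still compare as close.

    import torch

    a = torch.tensor([1e-9, 2e-9])
    b = torch.tensor([2e-9, 4e-9])           # every element is 2x larger
    print(torch.allclose(a, b))              # True: |a - b| <= 1e-08 + 1e-05 * |b|
    print(torch.allclose(a, b, atol=0.0))    # False: relative tolerance alone flags the difference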

misrasaurabh1 disabled auto-merge March 14, 2025 02:27

aseembits93 (Contributor) commented Mar 14, 2025

    import torch
    def torch_comparator(a,b):
        try:
            ans = torch.allclose(a,b,equal_nan=True)
        except RuntimeError:
            return False
        return ans
    a = torch.tensor([1,2,3],dtype=torch.float32)
    b = torch.tensor([1,2,3],dtype=torch.int64)
    print(torch_comparator(a,b))
    a = torch.tensor([1,2,3,4],dtype=torch.float32)
    b = torch.tensor([1,2,3],dtype=torch.float32)
    print(torch_comparator(a,b))
    a = torch.tensor([1,2,3],dtype=torch.float32).to(torch.device('mps'))
    b = torch.tensor([1,2,3],dtype=torch.float32)
    print(torch_comparator(a,b))
    a = torch.tensor([1,2,3],dtype=torch.float32)
    b = torch.tensor([1,2,3],dtype=torch.float32)
    print(torch_comparator(a,b))
    #wont work with requires_grad
    a = torch.tensor([1,2,3],dtype=torch.float32,requires_grad=True)
    b = torch.tensor([1,2,3],dtype=torch.float32,requires_grad=False)
    print(torch_comparator(a,b))
    

    @misrasaurabh1 try running this. here's the output

    False
    False
    False
    True
    True
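
To make the requires_grad caveat concrete: torch.allclose compares values only, so the last pair above compares as True even though the requires_grad flags differ. If that flag should matter, a hypothetical stricter comparator (not part of this PR) would have to check it explicitly:

    def torch_comparator_strict(a, b):
        # Treat a requires_grad mismatch as inequality; otherwise fall back to
        # allclose, which already raises RuntimeError on dtype/shape/device mismatches.
        if a.requires_grad != b.requires_grad:
            return False
        try:
            return torch.allclose(a, b, equal_nan=True)
        except RuntimeError:
            return False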
    

misrasaurabh1 enabled auto-merge March 14, 2025 05:15
misrasaurabh1 merged commit 5ee8ad9 into main March 14, 2025
15 checks passed