Conversation


@Silv3S Silv3S commented Oct 10, 2025

Summary

torch.special.logit runs in higher precision for bfloat16 and float16 inputs, because the input is cast to AccumulateTypeDevice, which is float32 (pytorch/aten/src/ATen/AccumulateType.h). The output is cast back to the lower precision, but because the intermediate results are computed in float32, we get different results than on CPU. This might affect other tests, so I wanted to clarify whether this is expected or whether we should always try to match the CPU reference in our kernels.
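
To make the difference concrete, the two evaluation orders can be emulated on CPU alone. The snippet below is only an illustration of the rounding behaviour described above (it assumes nothing beyond stock PyTorch and is not the kernel code):

import torch

x = torch.tensor([0.5234], dtype=torch.bfloat16)

# Keep every intermediate in bfloat16.
native = torch.log(x / (1 - x))

# Emulate the AccumulateTypeDevice path: upcast to float32, evaluate the whole
# expression there, then round only the final result back to bfloat16.
accumulated = torch.log(x.float() / (1 - x.float())).to(torch.bfloat16)

# The two results differ in the last bits for this input.
print(native.item(), accumulated.item())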

Minimal repro

import torch
import pytest

@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["float32", "bfloat16"])
def test_special_logit(dtype):
    input_cpu = torch.tensor([0.5234], device="cpu", dtype=dtype)
    input_xpu = input_cpu.to("xpu")

    reference_cpu = torch.log(input_cpu/(1 - input_cpu))
    reference_xpu = torch.log(input_xpu/(1 - input_xpu))
    print(f"reference_cpu logit: {reference_cpu}")
    print(f"reference_xpu logit: {reference_xpu}")
    assert torch.allclose(reference_cpu, reference_xpu.cpu(), atol=1e-5, rtol=1e-5)

    logit_cpu = torch.special.logit(input_cpu)
    logit_xpu = torch.special.logit(input_xpu)
    print(f"CPU logit: {logit_cpu}")
    print(f"XPU logit: {logit_xpu}")
    assert torch.allclose(logit_cpu, logit_xpu.cpu(), atol=1e-5, rtol=1e-5)

Results

| device | dtype | reference | torch.special.logit | torch.special.logit (fix) |
|--------|-------|-----------|---------------------|---------------------------|
| CPU    | fp32  | 0.0937    | 0.0937              |                           |
| XPU    | fp32  | 0.0937    | 0.0937              | 0.0937                    |
| CUDA   | fp32  | 0.0937    | 0.0937              |                           |
| CPU    | bf16  | 0.0967    | 0.0967              |                           |
| XPU    | bf16  | 0.0967    | 0.0938              | 0.0967                    |
| CUDA   | bf16  | 0.0967    | 0.0938              |                           |

@Copilot Copilot AI review requested due to automatic review settings October 10, 2025 11:56

@Copilot Copilot AI left a comment

Pull Request Overview

This PR fixes a precision issue with torch.special.logit for bfloat16 and float16 inputs by modifying the kernel to run computations in reduced precision instead of casting to higher precision (float32). The change ensures consistency between CPU and XPU device results for these reduced-precision floating-point types.

  • Simplified logit computation to use native input precision instead of accumulate type casting
  • Renamed functors for clarity (Logit0Functor → LogitFunctor, Logit1Functor → LogitEpsFunctor)
  • Updated parameter names and types to match the new precision-preserving approach
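
The first bullet is the core of the change. For context, the documented semantics of torch.special.logit are log(x / (1 - x)), with the input clamped to [eps, 1 - eps] when eps is given. A minimal Python sketch of that definition with every intermediate kept in the input dtype (purely illustrative; the helper name logit_reference is hypothetical and this is not the actual SYCL functor) would look like:

import torch

def logit_reference(x, eps=None):
    # Clamp to [eps, 1 - eps] when eps is given (documented torch.special.logit
    # behaviour), then evaluate log(x / (1 - x)) entirely in x.dtype, i.e. with
    # no upcast of the intermediates to float32.
    if eps is not None:
        x = x.clamp(eps, 1.0 - eps)
    return torch.log(x / (1 - x))

x = torch.tensor([0.5234], dtype=torch.bfloat16)
print(logit_reference(x, eps=1e-6))
print(torch.special.logit(x, eps=1e-6))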

Comment on lines +143 to +150

+     scalar_t x_clamped = x < low_ ? low_ : (x > high_ ? high_ : x);
+     return std::log(x_clamped / (1 - x_clamped));
      }
-     Logit1Functor(const T_ACC lo, const T_ACC hi) : lo_(lo), hi_(hi) {}
+     LogitEpsFunctor(const T_ACC low, const T_ACC high) : low_(low), high_(high) {}

     private:
-     T_ACC lo_;
-     T_ACC hi_;
+     scalar_t low_;
+     scalar_t high_;

Copilot AI Oct 10, 2025

Type mismatch: low_ and high_ are of type scalar_t but are being compared with x and assigned from T_ACC values in the constructor. This could cause precision loss or incorrect comparisons when scalar_t and T_ACC differ.

Comment on lines +143 to +150

Copilot AI Oct 10, 2025

Constructor parameters low and high are of type T_ACC but member variables low_ and high_ are of type scalar_t. This implicit conversion may cause precision loss when T_ACC has higher precision than scalar_t.

@Silv3S Silv3S changed the title from "Run torch.special.logit in reduced precision, for bf16/f16 inputs" to "Run torch.special.logit in reduced precision for bf16/f16 inputs" Oct 10, 2025

@australopitek australopitek left a comment

LGTM, but we need to come up with a way of handling such discrepancies between CPU and CUDA results in the future, and stick to it. Currently, CPU gives different results than CUDA for these ops.
