Skip to content

Commit d3a4b63

Browse files
committed
fix lint
1 parent eef906c commit d3a4b63

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

aten/src/ATen/Context.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -689,16 +689,18 @@ struct TORCH_API NoTF32Guard {
689689
bool changed = false;
690690
};
691691

692-
template<Float32Backend target_backend, Float32Op target_op>
692+
template <Float32Backend target_backend, Float32Op target_op>
693693
struct Fp32PrecisonGuard {
694694
Fp32PrecisonGuard(const Float32Precision new_precision) {
695695
if (new_precision == Float32Precision::NONE) {
696-
return ;
696+
return;
697697
}
698-
saved_precision = globalContext().float32Precision(target_backend, target_op);
698+
saved_precision =
699+
globalContext().float32Precision(target_backend, target_op);
699700
changed = (new_precision != saved_precision);
700701
if (changed) {
701-
globalContext().setFloat32Precision(target_backend, target_op, new_precision);
702+
globalContext().setFloat32Precision(
703+
target_backend, target_op, new_precision);
702704
}
703705
}
704706
Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete;
@@ -707,9 +709,11 @@ struct Fp32PrecisonGuard {
707709
Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete;
708710
~Fp32PrecisonGuard() {
709711
if (changed) {
710-
globalContext().setFloat32Precision(target_backend, target_op, saved_precision);
712+
globalContext().setFloat32Precision(
713+
target_backend, target_op, saved_precision);
711714
}
712715
}
716+
713717
private:
714718
Float32Precision saved_precision;
715719
bool changed = false;

torch/backends/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def __setattr__(self, name, value):
205205
return torch._C._set_fp32_precision_setter("cuda", "matmul", value)
206206
raise AttributeError("Unknown attribute " + name)
207207

208+
208209
class MathSDPModule:
209210
def __getattr__(self, name):
210211
if name == "fp32_precision":

0 commit comments

Comments
 (0)