Skip to content

Commit 8284287

Browse files
committed
fix lint
1 parent 6e6c0c6 commit 8284287

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

aten/src/ATen/Context.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -695,16 +695,18 @@ struct TORCH_API NoTF32Guard {
695695
bool changed = false;
696696
};
697697

698-
template<Float32Backend target_backend, Float32Op target_op>
698+
template <Float32Backend target_backend, Float32Op target_op>
699699
struct Fp32PrecisonGuard {
700700
Fp32PrecisonGuard(const Float32Precision new_precision) {
701701
if (new_precision == Float32Precision::NONE) {
702-
return ;
702+
return;
703703
}
704-
saved_precision = globalContext().float32Precision(target_backend, target_op);
704+
saved_precision =
705+
globalContext().float32Precision(target_backend, target_op);
705706
changed = (new_precision != saved_precision);
706707
if (changed) {
707-
globalContext().setFloat32Precision(target_backend, target_op, new_precision);
708+
globalContext().setFloat32Precision(
709+
target_backend, target_op, new_precision);
708710
}
709711
}
710712
Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete;
@@ -713,9 +715,11 @@ struct Fp32PrecisonGuard {
713715
Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete;
714716
~Fp32PrecisonGuard() {
715717
if (changed) {
716-
globalContext().setFloat32Precision(target_backend, target_op, saved_precision);
718+
globalContext().setFloat32Precision(
719+
target_backend, target_op, saved_precision);
717720
}
718721
}
722+
719723
private:
720724
Float32Precision saved_precision;
721725
bool changed = false;

torch/backends/cuda/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def __setattr__(self, name, value):
205205
return torch._C._set_fp32_precision_setter("cuda", "matmul", value)
206206
raise AttributeError("Unknown attribute " + name)
207207

208+
208209
class MathSDPModule:
209210
def __getattr__(self, name):
210211
if name == "fp32_precision":

0 commit comments

Comments
 (0)