@@ -170,26 +170,26 @@ def initialize_cuda_context_rng():
170170
@contextlib.contextmanager
def tf32_off():
    """Context manager that disables TF32 for the duration of the block.

    Forces full-precision (TF32-off) float32 matmuls via
    ``torch.backends.cuda.matmul.allow_tf32`` and disables TF32 in cuDNN
    through ``torch.backends.cudnn.flags``; the previous matmul setting is
    restored on exit, even if the body raises.
    """
    saved_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        # enabled/benchmark/deterministic are passed as None so the current
        # cudnn settings for those flags are left untouched.
        with torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=False
        ):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = saved_matmul_tf32
180180
181181
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    """Context manager that enables TF32 and loosens the test precision.

    Turns on TF32 for cuBLAS matmuls and cuDNN, and temporarily replaces
    ``self.precision`` (the test case's comparison tolerance) with
    *tf32_precision*. Both the matmul flag and the precision are restored
    on exit, even if the body raises.
    """
    saved_matmul_tf32 = torch.backends.cuda.matmul.allow_tf32
    saved_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        # enabled/benchmark/deterministic are passed as None so the current
        # cudnn settings for those flags are left untouched.
        with torch.backends.cudnn.flags(
            enabled=None, benchmark=None, deterministic=None, allow_tf32=True
        ):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = saved_matmul_tf32
        self.precision = saved_precision
194194
195195
0 commit comments