|
22 | 22 |
|
23 | 23 | import torch |
24 | 24 |
|
25 | | -if TYPE_CHECKING: |
26 | | - from beartype.typing import List |
27 | 25 | from torch import Tensor |
28 | 26 | from torch.optim.optimizer import Optimizer |
29 | 27 |
|
@@ -242,12 +240,12 @@ def step(self, closure=None): |
242 | 240 |
|
243 | 241 |
|
244 | 242 | def _single_tensor_adan( |
245 | | - params: List[Tensor], |
246 | | - grads: List[Tensor], |
247 | | - exp_avgs: List[Tensor], |
248 | | - exp_avg_sqs: List[Tensor], |
249 | | - exp_avg_diffs: List[Tensor], |
250 | | - neg_pre_grads: List[Tensor], |
| 243 | + params: list[Tensor], |
| 244 | + grads: list[Tensor], |
| 245 | + exp_avgs: list[Tensor], |
| 246 | + exp_avg_sqs: list[Tensor], |
| 247 | + exp_avg_diffs: list[Tensor], |
| 248 | + neg_pre_grads: list[Tensor], |
251 | 249 | *, |
252 | 250 | beta1: float, |
253 | 251 | beta2: float, |
@@ -297,12 +295,12 @@ def _single_tensor_adan( |
297 | 295 |
|
298 | 296 |
|
299 | 297 | def _multi_tensor_adan( |
300 | | - params: List[Tensor], |
301 | | - grads: List[Tensor], |
302 | | - exp_avgs: List[Tensor], |
303 | | - exp_avg_sqs: List[Tensor], |
304 | | - exp_avg_diffs: List[Tensor], |
305 | | - neg_pre_grads: List[Tensor], |
| 298 | + params: list[Tensor], |
| 299 | + grads: list[Tensor], |
| 300 | + exp_avgs: list[Tensor], |
| 301 | + exp_avg_sqs: list[Tensor], |
| 302 | + exp_avg_diffs: list[Tensor], |
| 303 | + neg_pre_grads: list[Tensor], |
306 | 304 | *, |
307 | 305 | beta1: float, |
308 | 306 | beta2: float, |
@@ -356,12 +354,12 @@ def _multi_tensor_adan( |
356 | 354 |
|
357 | 355 |
|
358 | 356 | def _fused_adan_multi_tensor( |
359 | | - params: List[Tensor], |
360 | | - grads: List[Tensor], |
361 | | - exp_avgs: List[Tensor], |
362 | | - exp_avg_sqs: List[Tensor], |
363 | | - exp_avg_diffs: List[Tensor], |
364 | | - neg_pre_grads: List[Tensor], |
| 357 | + params: list[Tensor], |
| 358 | + grads: list[Tensor], |
| 359 | + exp_avgs: list[Tensor], |
| 360 | + exp_avg_sqs: list[Tensor], |
| 361 | + exp_avg_diffs: list[Tensor], |
| 362 | + neg_pre_grads: list[Tensor], |
365 | 363 | *, |
366 | 364 | beta1: float, |
367 | 365 | beta2: float, |
@@ -400,12 +398,12 @@ def _fused_adan_multi_tensor( |
400 | 398 |
|
401 | 399 |
|
402 | 400 | def _fused_adan_single_tensor( |
403 | | - params: List[Tensor], |
404 | | - grads: List[Tensor], |
405 | | - exp_avgs: List[Tensor], |
406 | | - exp_avg_sqs: List[Tensor], |
407 | | - exp_avg_diffs: List[Tensor], |
408 | | - neg_pre_grads: List[Tensor], |
| 401 | + params: list[Tensor], |
| 402 | + grads: list[Tensor], |
| 403 | + exp_avgs: list[Tensor], |
| 404 | + exp_avg_sqs: list[Tensor], |
| 405 | + exp_avg_diffs: list[Tensor], |
| 406 | + neg_pre_grads: list[Tensor], |
409 | 407 | *, |
410 | 408 | beta1: float, |
411 | 409 | beta2: float, |
|
0 commit comments