@@ -145,10 +145,8 @@ def __init__(
145145 surf_stats (dict[str, tuple[float, float]], optional): For these surface-level
146146 variables, adjust the normalisation to the given tuple consisting of a new location
147147 and scale.
148- bf16_mode (bool, optional): To reduce memory usage, convert the tokens to BF16, run
149- the backbone in pure BF16, and run the decoder in FP16 AMP. This should enable a
150- gradient computation. USE AT YOUR OWN RISK. THIS WAS NOT USED DURING THE DEVELOPMENT
151- OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT FOR FINE-TUNING.
148+ autocast (bool, optional): To reduce memory usage, use `torch.autocast` to run only the
149+ backbone in BF16. This is critical for enabling fine-tuning.
152150 level_condition (tuple[int | float, ...], optional): Make the patch embeddings dependent
153151 on pressure level. If you want to enable this feature, provide a tuple of all
154152 possible pressure levels.
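A hedged usage sketch of the new flag (not part of this commit; it assumes `Aurora` is importable from the `aurora` package and accepts `autocast` as a keyword argument):

```python
from aurora import Aurora  # assumed import path

# For fine-tuning, enable BF16 autocasting of the backbone to cut activation
# memory; the encoder and decoder keep running in FP32.
model = Aurora(autocast=True)
model.train()
```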
@@ -228,6 +226,7 @@ def __init__(
228226 embed_dim=embed_dim,
229227 mlp_ratio=mlp_ratio,
230228 drop_path_rate=drop_path,
229+ attn_drop_rate=drop_rate,
231230 drop_rate=drop_rate,
232231 use_lora=use_lora,
233232 lora_steps=lora_steps,
@@ -252,18 +251,16 @@ def __init__(
252251 modulation_heads=modulation_heads,
253252 )
254253
255- if autocast and not bf16_mode:
254+ if bf16_mode and not autocast:
256255 warnings.warn(
257- "The argument `autocast` no longer does anything due to limited utility. "
258- "Consider instead using `bf16_mode`.",
256+ "`bf16_mode` was removed because it caused serious issues for gradient "
257+ "computation. `bf16_mode` now automatically activates `autocast`, which will not "
258+ "save as much memory, but should be much more stable.",
259259 stacklevel=2,
260260 )
261+ autocast = True
261262
262- self.bf16_mode = bf16_mode
263-
264- if self.bf16_mode:
265- # We run the backbone in pure BF16.
266- self.backbone.to(torch.bfloat16)
263+ self.autocast = autocast
267264
268265 def forward(self, batch: Batch) -> Batch:
269266 """Forward pass.
@@ -327,44 +324,30 @@ def forward(self, batch: Batch) -> Batch:
327324 lead_time=self.timestep,
328325 )
329326
330- # In BF16 mode, the backbone is run in pure BF16.
331- if self.bf16_mode:
332- x = x.to(torch.bfloat16)
333- x = self.backbone(
334- x,
335- lead_time=self.timestep,
336- patch_res=patch_res,
337- rollout_step=batch.metadata.rollout_step,
338- )
339-
340- # In BF16 mode, the decoder is run in AMP FP16, and the output is converted back to FP32.
341- # We run in FP16 as opposed to BF16 for improved relative precision.
342- if self.bf16_mode:
343- device_type = (
344- "cuda"
345- if torch.cuda.is_available()
346- else "xpu"
347- if torch.xpu.is_available()
348- else "cpu"
349- )
350- context = torch.autocast(device_type=device_type, dtype=torch.float16)
351- x = x.to(torch.float16)
327+ if self.autocast:
328+ if torch.cuda.is_available():
329+ device_type = "cuda"
330+ elif torch.xpu.is_available():
331+ device_type = "xpu"
332+ else:
333+ device_type = "cpu"
334+ context = torch.autocast(device_type=device_type, dtype=torch.bfloat16)
352335 else:
353336 context = contextlib.nullcontext()
354337 with context:
355- pred = self.decoder(
338+ x = self.backbone(
356339 x,
357- batch,
358340 lead_time=self.timestep,
359341 patch_res=patch_res,
342+ rollout_step=batch.metadata.rollout_step,
360343 )
361- if self.bf16_mode:
362- pred = dataclasses.replace(
363- pred,
364- surf_vars={k: v.float() for k, v in pred.surf_vars.items()},
365- static_vars={k: v.float() for k, v in pred.static_vars.items()},
366- atmos_vars={k: v.float() for k, v in pred.atmos_vars.items()},
367- )
344+
345+ pred = self.decoder(
346+ x,
347+ batch,
348+ lead_time=self.timestep,
349+ patch_res=patch_res,
350+ )
368351
369352 # Remove batch and history dimension from static variables.
370353 pred = dataclasses.replace(
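The device-type selection and `nullcontext` fallback in the new `forward` follow a pattern that can be factored out; a standalone sketch (names are illustrative, not from the commit):

```python
import contextlib

import torch

def bf16_autocast_context(enabled: bool):
    """Return a BF16 autocast context for the best available device, or a no-op."""
    if not enabled:
        return contextlib.nullcontext()
    if torch.cuda.is_available():
        device_type = "cuda"
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device_type = "xpu"
    else:
        device_type = "cpu"
    return torch.autocast(device_type=device_type, dtype=torch.bfloat16)
```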
@@ -520,27 +503,49 @@ def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor])
520503
521504 checkpoint[name] = new_weight
522505
523- def configure_activation_checkpointing(self):
506+ def configure_activation_checkpointing(
507+ self,
508+ module_names: tuple[str, ...] = (
509+ "Basic3DDecoderLayer",
510+ "Basic3DEncoderLayer",
511+ "LinearPatchReconstruction",
512+ "Perceiver3DDecoder",
513+ "Perceiver3DEncoder",
514+ "Swin3DTransformerBackbone",
515+ "Swin3DTransformerBlock",
516+ ),
517+ ) -> None:
524518 """Configure activation checkpointing.
525519
526520 This is required in order to compute gradients without running out of memory.
521+
522+ Args:
523+ module_names (tuple[str, ...], optional): Names of the modules to checkpoint
524+ on.
525+
526+ Raises:
527+ RuntimeError: If any module specified in `module_names` was not found and
528+ thus could not be checkpointed.
527529 """
528- # Checkpoint these modules:
529- module_names = (
530- "Perceiver3DEncoder" ,
531- "Swin3DTransformerBackbone" ,
532- "Basic3DEncoderLayer" ,
533- "Basic3DDecoderLayer" ,
534- "Perceiver3DDecoder" ,
535- "LinearPatchReconstruction" ,
536- )
530+
531+ found: set[str] = set()
537532
538533 def check(x: torch.nn.Module) -> bool:
539534 name = x.__class__.__name__
540- return name in module_names
535+ if name in module_names:
536+ found.add(name)
537+ return True
538+ else:
539+ return False
541540
542541 apply_activation_checkpointing(self, check_fn=check)
543542
543+ if found != set(module_names):
544+ raise RuntimeError(
545+ f'Could not checkpoint on the following modules: '
546+ f'{", ".join(sorted(set(module_names) - found))}.'
547+ )
548+
544549
545550 class AuroraPretrained(Aurora):
546551 """Pretrained version of Aurora."""
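Taken together with `autocast`, the checkpointing hook is what lets gradient computation fit in memory during fine-tuning. A hedged end-to-end sketch (assuming the import path and that default construction works):

```python
from aurora import Aurora  # assumed import path

model = Aurora(autocast=True)
model.configure_activation_checkpointing()  # checkpoint the default module set

# Alternatively, restrict checkpointing to a subset of modules. A misspelled or
# absent class name now raises RuntimeError instead of silently doing nothing:
# model.configure_activation_checkpointing(module_names=("Swin3DTransformerBlock",))
```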