@@ -364,12 +364,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class HFLanguageRepresentationNetwork(nn.Module):
     def __init__(self,
-                 model_path: str = 'google-bert/bert-base-uncased',
-                 embedding_size: int = 768,
-                 group_size: int = 8,
-                 norm_type: str = "simnorm",
-                 # norm_type: str = "layernorm",  # TODO: Why does nan appear in the first step of training?
-                 tokenizer=None):
+                 model_path: str = 'google-bert/bert-base-uncased',
+                 embedding_size: int = 768,
+                 group_size: int = 8,
+                 final_norm_option_in_encoder: str = "layernorm",
+                 tokenizer=None):
         """
         Overview:
             This class defines a language representation network that utilizes a pretrained Hugging Face model.
@@ -379,7 +378,7 @@ def __init__(self,
             - model_path (str): The path to the pretrained Hugging Face model. Default is 'google-bert/bert-base-uncased'.
             - embedding_size (int): The dimension of the output embeddings. Default is 768.
             - group_size (int): The group size for SimNorm when using normalization.
-            - norm_type (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
+            - final_norm_option_in_encoder (str): The type of normalization to use ("simnorm" or "layernorm"). Default is "layernorm".
             - tokenizer (Optional): An instance of a tokenizer. If None, the tokenizer will be loaded from the pretrained model.
         """
         super().__init__()
@@ -389,12 +388,13 @@ def __init__(self,
 
         # In distributed training, only the rank 0 process downloads the model, and other processes load from cache to speed up startup.
         if get_rank() == 0:
-            self.model = AutoModel.from_pretrained(model_path)
+            self.pretrained_model = AutoModel.from_pretrained(model_path)
+
         if get_world_size() > 1:
             # Wait for rank 0 to finish loading the model.
             torch.distributed.barrier()
         if get_rank() != 0:
-            self.model = AutoModel.from_pretrained(model_path)
+            self.pretrained_model = AutoModel.from_pretrained(model_path)
 
         if tokenizer is None:
             # Only rank 0 downloads the tokenizer, and then other processes load it from cache.
@@ -409,15 +409,15 @@ def __init__(self,
 
         # Set the embedding dimension. A linear projection is added (the dimension remains unchanged here but can be extended for other mappings).
         self.embedding_size = embedding_size
-        self.embed_proj_head = nn.Linear(self.model.config.hidden_size, self.embedding_size)
+        self.embed_proj_head = nn.Linear(self.pretrained_model.config.hidden_size, self.embedding_size)
 
-        # Select the normalization method based on the norm_type parameter.
-        if norm_type.lower() == "simnorm":
+        # Select the normalization method based on the final_norm_option_in_encoder parameter.
+        if final_norm_option_in_encoder.lower() == "simnorm":
             self.norm = SimNorm(simnorm_dim=group_size)
-        elif norm_type.lower() == "layernorm":
+        elif final_norm_option_in_encoder.lower() == "layernorm":
             self.norm = nn.LayerNorm(embedding_size)
         else:
-            raise NotImplementedError(f"Normalization type '{norm_type}' is not implemented. "
+            raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
                                       f"Choose 'simnorm' or 'layernorm'.")
 
     def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
@@ -433,26 +433,27 @@ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
         Returns:
             - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size].
         """
+
         # Construct the attention mask to exclude padding tokens.
         attention_mask = x != self.tokenizer.pad_token_id
 
         # Use no_grad context if specified to disable gradient computation.
         if no_grad:
             with torch.no_grad():
                 x = x.long()  # Ensure the input tensor is of type long.
-                outputs = self.model(x, attention_mask=attention_mask)
+                outputs = self.pretrained_model(x, attention_mask=attention_mask)
                 # Get the hidden state from the last layer and select the output corresponding to the [CLS] token.
                 cls_embedding = outputs.last_hidden_state[:, 0, :]
         else:
             x = x.long()
-            outputs = self.model(x, attention_mask=attention_mask)
+            outputs = self.pretrained_model(x, attention_mask=attention_mask)
             cls_embedding = outputs.last_hidden_state[:, 0, :]
 
         # Apply linear projection to obtain the desired output dimension.
         cls_embedding = self.embed_proj_head(cls_embedding)
         # Normalize the embeddings using the selected normalization layer (SimNorm or LayerNorm) to ensure training stability.
         cls_embedding = self.norm(cls_embedding)
-
+
         return cls_embedding
 
 
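
For reference, a minimal usage sketch of the renamed interface. This is illustrative only and not part of the commit: the class name, constructor arguments, and forward signature come from the diff above; everything else is assumed standard Hugging Face usage.

```python
# Illustrative usage sketch (not part of the commit). Assumes `transformers`
# is installed and a single-process (non-distributed) setup, so get_rank() == 0
# and get_world_size() == 1 inside __init__.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
encoder = HFLanguageRepresentationNetwork(
    model_path='google-bert/bert-base-uncased',
    embedding_size=768,
    final_norm_option_in_encoder='layernorm',  # or 'simnorm'
    tokenizer=tokenizer,
)

batch = tokenizer(['go to the red door', 'pick up the key'],
                  return_tensors='pt', padding=True)
embedding = encoder(batch['input_ids'])  # gradients disabled by default (no_grad=True)
print(embedding.shape)  # torch.Size([2, 768])
```
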
@@ -468,6 +469,7 @@ def __init__(
             norm_type: str = 'BN',
             embedding_dim: int = 256,
             group_size: int = 8,
+            final_norm_option_in_encoder: str = 'LayerNorm',  # TODO
     ) -> None:
         """
         Overview:
@@ -486,6 +488,8 @@ def __init__(
             - norm_type (:obj:`str`): The type of normalization in networks, defaults to 'BN'.
             - embedding_dim (:obj:`int`): The dimension of the latent state.
             - group_size (:obj:`int`): The dimension for simplicial normalization.
+            - final_norm_option_in_encoder (:obj:`str`): The normalization option for the final layer, defaults to 'LayerNorm'. \
+                Options are 'SimNorm' and 'LayerNorm'.
         """
         super().__init__()
         assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
@@ -530,7 +534,14 @@ def __init__(
         elif self.observation_shape[1] in [84, 96]:
             self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
 
-        self.sim_norm = SimNorm(simnorm_dim=group_size)
+        self.final_norm_option_in_encoder = final_norm_option_in_encoder
+        if self.final_norm_option_in_encoder == 'LayerNorm':
+            self.final_norm = nn.LayerNorm(self.embedding_dim, eps=1e-5)
+        elif self.final_norm_option_in_encoder == 'SimNorm':
+            self.final_norm = SimNorm(simnorm_dim=group_size)
+        else:
+            raise ValueError(f"Unsupported final_norm_option_in_encoder: {self.final_norm_option_in_encoder}")
+
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -557,7 +568,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.view(-1, self.embedding_dim)
 
         # NOTE: very important for training stability.
-        x = self.sim_norm(x)
+        x = self.final_norm(x)
 
         return x
 
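
Both the language encoder and this convolutional encoder now switch between the same two final normalizations. As a rough illustration of what is being selected, here is a sketch under the assumption that SimNorm is the group-wise softmax (simplicial normalization) used elsewhere in UniZero; the repository's actual SimNorm class may differ in details.

```python
# Minimal sketch of the two final-norm options (illustration only; the
# repository's own SimNorm implementation may differ in naming and details).
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimNormSketch(nn.Module):
    """Simplicial normalization: softmax over fixed-size groups of features."""

    def __init__(self, simnorm_dim: int = 8):
        super().__init__()
        self.simnorm_dim = simnorm_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shp = x.shape
        # Reshape (B, D) -> (B, D // G, G) and apply softmax within each group,
        # so every group of G features is non-negative and sums to 1.
        x = x.view(*shp[:-1], -1, self.simnorm_dim)
        x = F.softmax(x, dim=-1)
        return x.view(*shp)


embedding = torch.randn(4, 256)
print(SimNormSketch(8)(embedding).shape)             # torch.Size([4, 256])
print(nn.LayerNorm(256, eps=1e-5)(embedding).shape)  # torch.Size([4, 256])
```
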
@@ -670,6 +681,7 @@ def __init__(
             activation: nn.Module = nn.GELU(approximate='tanh'),
             norm_type: Optional[str] = 'BN',
             group_size: int = 8,
+            final_norm_option_in_encoder: str = 'LayerNorm',  # TODO
     ) -> torch.Tensor:
         """
         Overview:
@@ -700,7 +712,15 @@ def __init__(
             # last_linear_layer_init_zero=True is beneficial for convergence speed.
             last_linear_layer_init_zero=True,
         )
-        self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+        # Select the normalization method based on the final_norm_option_in_encoder parameter.
+        if final_norm_option_in_encoder.lower() == "simnorm":
+            self.norm = SimNorm(simnorm_dim=group_size)
+        elif final_norm_option_in_encoder.lower() == "layernorm":
+            self.norm = nn.LayerNorm(hidden_channels)
+        else:
+            raise NotImplementedError(f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
+                                      f"Choose 'simnorm' or 'layernorm'.")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -709,8 +729,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
         """
         x = self.fc_representation(x)
-        # TODO
-        x = self.sim_norm(x)
+        x = self.norm(x)
+
         return x
 
 
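
The same simnorm/layernorm selection now appears in three encoders, with slightly different spellings ('LayerNorm'/'SimNorm' compared case-sensitively in the convolutional encoder, lower-cased 'layernorm'/'simnorm' in the other two). A small hypothetical helper like the one below, which is not part of this commit and assumes the SimNorm class already defined in this file, could keep the three copies consistent.

```python
# Hypothetical helper (not in the commit): one place to build the final encoder
# norm from final_norm_option_in_encoder, reusable by all three encoders.
import torch.nn as nn


def build_final_norm(final_norm_option_in_encoder: str, out_dim: int, group_size: int = 8) -> nn.Module:
    option = final_norm_option_in_encoder.lower()
    if option == "simnorm":
        return SimNorm(simnorm_dim=group_size)  # SimNorm as defined in this module
    if option == "layernorm":
        return nn.LayerNorm(out_dim, eps=1e-5)
    raise NotImplementedError(
        f"Normalization type '{final_norm_option_in_encoder}' is not implemented. "
        f"Choose 'simnorm' or 'layernorm'."
    )
```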