@@ -61,11 +61,13 @@ def forward(
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         token_type_ids = _decode_token_type_ids(input_ids)
 
-        inputs_embeds = self.word_embeddings(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
         position_embeddings = self.position_embeddings(position_ids)
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
@@ -358,11 +360,12 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if inputs_embeds is not None:
-            hidden_states = inputs_embeds
-        else:
-            hidden_states = self.embeddings(input_ids=input_ids,
-                                            position_ids=positions)
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=positions,
+            inputs_embeds=inputs_embeds,
+        )
+
         return self.encoder(hidden_states)
 
     def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
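For context, a minimal self-contained sketch of the pattern this diff introduces: the embeddings module looks up word embeddings only when the caller has not already supplied inputs_embeds, so precomputed embeddings can be passed straight through the model forward. The ToyEmbeddings class, tensor sizes, and the usage at the bottom are illustrative assumptions, not code from this repository.

# Sketch of the optional-inputs_embeds fallback shown in the diff above.
# Names and sizes here are hypothetical; only the control flow mirrors the PR.
from typing import Optional

import torch
import torch.nn as nn


class ToyEmbeddings(nn.Module):
    def __init__(self, vocab_size: int = 100, hidden_size: int = 16,
                 max_position: int = 32):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position, hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Only do the word-embedding lookup when the caller did not
        # provide precomputed embeddings.
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        return inputs_embeds + self.position_embeddings(position_ids)


if __name__ == "__main__":
    emb = ToyEmbeddings()
    ids = torch.tensor([[1, 2, 3]])
    pos = torch.arange(3).unsqueeze(0)
    # Path 1: embeddings computed from token ids.
    out_from_ids = emb(ids, pos)
    # Path 2: caller supplies precomputed embeddings; the lookup is skipped.
    precomputed = torch.zeros(1, 3, 16)
    out_from_embeds = emb(ids, pos, inputs_embeds=precomputed)
    print(out_from_ids.shape, out_from_embeds.shape)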