File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -238,6 +238,7 @@ def _from_pretrained(
238238 model = TasksManager .get_model_from_task (
239239 task ,
240240 model_id ,
241+ library_name = "transformers" ,
241242 trust_remote_code = trust_remote_code ,
242243 torch_dtype = torch_dtype ,
243244 _commit_hash = commit_hash ,
@@ -273,6 +274,7 @@ def forward(
273274 input_ids : torch .Tensor ,
274275 attention_mask : torch .Tensor ,
275276 token_type_ids : torch .Tensor = None ,
277+ position_ids : torch .Tensor = None ,
276278 ** kwargs ,
277279 ):
278280 inputs = {
@@ -283,6 +285,9 @@ def forward(
283285 if "token_type_ids" in self .input_names :
284286 inputs ["token_type_ids" ] = token_type_ids
285287
288+ if "position_ids" in self .input_names :
289+ inputs ["position_ids" ] = position_ids
290+
286291 outputs = self ._call_model (** inputs )
287292 if isinstance (outputs , dict ):
288293 model_output = ModelOutput (** outputs )
You can’t perform that action at this time.
0 commit comments