
Issues when calling model.predict(input) #1591

@Veronicacz


Describe the bug
When calling model.predict(input), I get a ValueError that I haven't encountered before with SimpleTransformers.

To Reproduce
transformers==4.55.4, simpletransformers==0.70.5

Sample code:

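import logging

from simpletransformers.classification import (
    MultiLabelClassificationArgs,
    MultiLabelClassificationModel,
)

logger = logging.getLogger(__name__)

# N_EPOCHS, args, train and test are assumed to be defined elsewhere in train.py (not shown)
model_args = MultiLabelClassificationArgs(
    num_train_epochs=N_EPOCHS,
    overwrite_output_dir=True,
    train_batch_size=32,
    output_dir=args.output_path,
    use_multiprocessing=False,
    use_multiprocessing_for_evaluation=False,
    save_steps=-1,  # set -1 to disable
    save_eval_checkpoints=False,
    save_model_every_epoch=False,
)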

logger.info('model args: ')
logger.info(model_args)

model = MultiLabelClassificationModel(
    'roberta',
    'distilroberta-base',
    use_cuda=True,
    args=model_args,
    num_labels=1
)

# Train the model
model.train_model(train)

test_data = list(test['Text'])
logger.info(f'Test len: {test.shape}')   # Test len: (66, 2)

# Evaluate on test set
predictions, raw_outputs = model.predict(test_data)

Screenshots

Predicting: 0%| | 0/1 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "train.py", line 135, in <module>
    predictions, raw_outputs = model.predict(test_data)
                               ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lib/python3.12/site-packages/simpletransformers/classification/multi_label_classification_model.py", line 406, in predict
    return super().predict(to_predict, multi_label=multi_label)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lib/python3.12/site-packages/simpletransformers/classification/classification_model.py", line 2295, in predict
    out_label_ids[start_index:end_index] = (
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
ValueError: could not broadcast input array from shape (66,) into shape (66,1)
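
In case it helps narrow things down: the message itself is NumPy's standard broadcasting error, raised when a 1-D array of shape (66,) is assigned into a slice of a 2-D array of shape (66, 1). My guess (not confirmed against the library source) is that with num_labels=1 the label buffer out_label_ids is allocated as (n, num_labels) while the labels for the batch arrive as a flat (n,) array. The sketch below reproduces the same error with plain NumPy; the names mirror the traceback, but the shapes and the flattened labels are my assumptions:

```python
import numpy as np

# Sizes taken from this report: 66 test rows, num_labels=1
n_samples, num_labels = 66, 1

# Assumption: predict() preallocates a 2-D label buffer of shape (n, num_labels)
out_label_ids = np.empty((n_samples, num_labels))

# Assumption: the labels for the batch come back flattened to shape (n,)
batch_labels = np.zeros(n_samples)

start_index, end_index = 0, n_samples
message = ""
try:
    # A (66,) array cannot broadcast into a (66, 1) slice, so NumPy raises ValueError
    out_label_ids[start_index:end_index] = batch_labels
except ValueError as err:
    message = str(err)
    print(message)  # same kind of message as in the traceback above

# Giving the labels an explicit (n, num_labels) shape makes the assignment succeed
out_label_ids[start_index:end_index] = batch_labels.reshape(-1, num_labels)
print(out_label_ids.shape)  # (66, 1)
```

This only reproduces the message; I haven't checked where inside predict() the flat labels would come from.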

Desktop (please complete the following information):

  • Ubuntu

Additional context
My dataset has two columns: text and labels.

@ThilinaRajapakse Thank you so much for fixing the previous issues; I really appreciate it!

I ran into this new issue when trying to make predictions with the trained model. I’m not sure whether it’s specific to my environment, but I hadn’t encountered it before, and I haven’t changed any code or packages apart from upgrading these two (transformers==4.55.4, simpletransformers==0.70.5).

Happy to share more details if needed!
