|
30 | 30 | class PromptModelTest(unittest.TestCase):
|
31 | 31 | @classmethod
|
32 | 32 | def setUpClass(cls):
|
33 |
| - cls.tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/ernie") |
34 |
| - cls.model = AutoModelForMaskedLM.from_pretrained("__internal_testing__/ernie") |
| 33 | + cls.tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-ernie") |
| 34 | + cls.model = AutoModelForMaskedLM.from_pretrained("__internal_testing__/tiny-random-ernie") |
| 35 | + cls.num_labels = 2 |
| 36 | + cls.seq_cls_model = AutoModelForSequenceClassification.from_pretrained( |
| 37 | + "__internal_testing__/tiny-random-ernie", num_labels=cls.num_labels |
| 38 | + ) |
35 | 39 |
|
36 | 40 | cls.template = AutoTemplate.create_from(
|
37 | 41 | prompt="{'soft'}{'text': 'text'}{'mask'}", tokenizer=cls.tokenizer, max_length=512, model=cls.model
|
@@ -71,36 +75,32 @@ def test_sequence_classification_with_labels(self):
|
71 | 75 | self.assertEqual(model_outputs.hidden_states.shape[0], len(examples))
|
72 | 76 |
|
73 | 77 | def test_efl_no_labels(self):
|
74 |
| - num_labels = 2 |
75 |
| - model = AutoModelForSequenceClassification.from_pretrained("__internal_testing__/ernie", num_labels=num_labels) |
76 |
| - prompt_model = PromptModelForSequenceClassification(model, self.template, verbalizer=None) |
| 78 | + prompt_model = PromptModelForSequenceClassification(self.seq_cls_model, self.template, verbalizer=None) |
77 | 79 | examples = [{"text": "百度飞桨深度学习框架"}, {"text": "这是一个测试"}]
|
78 | 80 | encoded_examples = [self.template(i) for i in examples]
|
79 | 81 | logits, hidden_states = prompt_model(**self.data_collator(encoded_examples))
|
80 | 82 | self.assertEqual(logits.shape[0], len(examples))
|
81 |
| - self.assertEqual(logits.shape[1], num_labels) |
| 83 | + self.assertEqual(logits.shape[1], self.num_labels) |
82 | 84 | self.assertEqual(hidden_states.shape[0], len(examples))
|
83 | 85 |
|
84 | 86 | model_outputs = prompt_model(**self.data_collator(encoded_examples), return_dict=True)
|
85 | 87 | self.assertIsNone(model_outputs.loss)
|
86 | 88 | self.assertEqual(model_outputs.logits.shape[0], len(examples))
|
87 |
| - self.assertEqual(model_outputs.logits.shape[1], num_labels) |
| 89 | + self.assertEqual(model_outputs.logits.shape[1], self.num_labels) |
88 | 90 | self.assertEqual(model_outputs.hidden_states.shape[0], len(examples))
|
89 | 91 |
|
90 | 92 | def test_efl_with_labels(self):
|
91 |
| - num_labels = 2 |
92 |
| - model = AutoModelForSequenceClassification.from_pretrained("__internal_testing__/ernie", num_labels=num_labels) |
93 |
| - prompt_model = PromptModelForSequenceClassification(model, self.template, verbalizer=None) |
| 93 | + prompt_model = PromptModelForSequenceClassification(self.seq_cls_model, self.template, verbalizer=None) |
94 | 94 | examples = [{"text": "百度飞桨深度学习框架", "labels": 0}, {"text": "这是一个测试", "labels": 1}]
|
95 | 95 | encoded_examples = [self.template(i) for i in examples]
|
96 | 96 | loss, logits, hidden_states = prompt_model(**self.data_collator(encoded_examples))
|
97 | 97 | self.assertIsNotNone(loss)
|
98 | 98 | self.assertEqual(logits.shape[0], len(examples))
|
99 |
| - self.assertEqual(logits.shape[1], num_labels) |
| 99 | + self.assertEqual(logits.shape[1], self.num_labels) |
100 | 100 | self.assertEqual(hidden_states.shape[0], len(examples))
|
101 | 101 |
|
102 | 102 | model_outputs = prompt_model(**self.data_collator(encoded_examples), return_dict=True)
|
103 | 103 | self.assertIsNotNone(model_outputs.loss)
|
104 | 104 | self.assertEqual(model_outputs.logits.shape[0], len(examples))
|
105 |
| - self.assertEqual(model_outputs.logits.shape[1], num_labels) |
| 105 | + self.assertEqual(model_outputs.logits.shape[1], self.num_labels) |
106 | 106 | self.assertEqual(model_outputs.hidden_states.shape[0], len(examples))
|
0 commit comments