Skip to content

Commit a06f59a

Browse files
authored
Fix text classification for community (#288)
1 parent b4be074 commit a06f59a

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

api-inference-community/docker_images/common/app/pipelines/text_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __call__(self, inputs: str) -> List[Dict[str, float]]:
2222
inputs (:obj:`str`):
2323
a string containing some text
2424
Return:
25-
A :obj:`list`:. The object returned should be like [{"label": 0.9939950108528137}] containing :
25+
A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing:
2626
- "label": A string representing what the label/class is. There can be multiple labels.
2727
- "score": A score between 0 and 1 describing how confident the model is for this label/class.
2828
"""

api-inference-community/docker_images/common/tests/test_api_text_classification.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ def test_simple(self):
4949
)
5050
content = json.loads(response.content)
5151
self.assertEqual(type(content), list)
52+
self.assertEqual(len(content), 1)
53+
self.assertEqual(type(content[0]), list)
5254
self.assertEqual(
53-
set(k for el in content for k in el.keys()),
55+
set(k for el in content[0] for k in el.keys()),
5456
{"label", "score"},
5557
)
5658

@@ -63,8 +65,10 @@ def test_simple(self):
6365
)
6466
content = json.loads(response.content)
6567
self.assertEqual(type(content), list)
68+
self.assertEqual(len(content), 1)
69+
self.assertEqual(type(content[0]), list)
6670
self.assertEqual(
67-
set(k for el in content for k in el.keys()),
71+
set(k for el in content[0] for k in el.keys()),
6872
{"label", "score"},
6973
)
7074

api-inference-community/docker_images/spacy/app/pipelines/text_classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ def __init__(
2828

2929
self.model = spacy.load(model_name)
3030

31-
def __call__(self, inputs: str) -> List[Dict[str, float]]:
31+
def __call__(self, inputs: str) -> List[List[Dict[str, float]]]:
3232
"""
3333
Args:
3434
inputs (:obj:`str`):
3535
a string containing some text
3636
Return:
37-
A :obj:`list`:. The object returned should be like [{"label": 0.9939950108528137}] containing :
37+
A :obj:`list`:. The object returned should be a list of one list like [[{"label": 0.9939950108528137}]] containing :
3838
- "label": A string representing what the label/class is. There can be multiple labels.
3939
- "score": A score between 0 and 1 describing how confident the model is for this label/class.
4040
"""
@@ -44,4 +44,4 @@ def __call__(self, inputs: str) -> List[Dict[str, float]]:
4444
for cat, score in doc.cats.items():
4545
categories.append({"label": cat, "score": score})
4646

47-
return categories
47+
return [categories]

api-inference-community/docker_images/spacy/tests/test_api_text_classification.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ def test_simple(self):
4949
)
5050
content = json.loads(response.content)
5151
self.assertEqual(type(content), list)
52+
self.assertEqual(len(content), 1)
53+
self.assertEqual(type(content[0]), list)
5254
self.assertEqual(
53-
set(k for el in content for k in el.keys()),
55+
set(k for el in content[0] for k in el.keys()),
5456
{"label", "score"},
5557
)
5658

@@ -63,8 +65,10 @@ def test_simple(self):
6365
)
6466
content = json.loads(response.content)
6567
self.assertEqual(type(content), list)
68+
self.assertEqual(len(content), 1)
69+
self.assertEqual(type(content[0]), list)
6670
self.assertEqual(
67-
set(k for el in content for k in el.keys()),
71+
set(k for el in content[0] for k in el.keys()),
6872
{"label", "score"},
6973
)
7074

0 commit comments

Comments
 (0)