Skip to content

Commit 0471177

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3f056b9 commit 0471177

File tree

12 files changed

+152
-119
lines changed

12 files changed

+152
-119
lines changed

argilla/src/argilla/cli/datasets/__main__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ def callback(
4444
dataset = FeedbackDataset.from_argilla(name=name, workspace=workspace)
4545
except ValueError as e:
4646
echo_in_panel(
47-
f"`FeedbackDataset` with name={name} not found in Argilla. Try using '--workspace' option."
48-
if not workspace
49-
else f"`FeedbackDataset with name={name} and workspace={workspace} not found in Argilla.",
47+
(
48+
f"`FeedbackDataset` with name={name} not found in Argilla. Try using '--workspace' option."
49+
if not workspace
50+
else f"`FeedbackDataset with name={name} and workspace={workspace} not found in Argilla."
51+
),
5052
title="Dataset not found",
5153
title_align="left",
5254
success=False,

argilla/src/argilla/client/datasets.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -744,16 +744,20 @@ def _to_datasets_dict(self) -> Dict:
744744
for key in self._RECORD_TYPE.__fields__:
745745
if key == "prediction":
746746
ds_dict[key] = [
747-
[{"label": pred[0], "score": pred[1]} for pred in rec.prediction]
748-
if rec.prediction is not None
749-
else None
747+
(
748+
[{"label": pred[0], "score": pred[1]} for pred in rec.prediction]
749+
if rec.prediction is not None
750+
else None
751+
)
750752
for rec in self._records
751753
]
752754
elif key == "explanation":
753755
ds_dict[key] = [
754-
{key: list(map(dict, tokattrs)) for key, tokattrs in rec.explanation.items()}
755-
if rec.explanation is not None
756-
else None
756+
(
757+
{key: list(map(dict, tokattrs)) for key, tokattrs in rec.explanation.items()}
758+
if rec.explanation is not None
759+
else None
760+
)
757761
for rec in self._records
758762
]
759763
elif key == "id":
@@ -1255,9 +1259,11 @@ def entities_to_dict(
12551259
if entities is None:
12561260
return None
12571261
return [
1258-
{"label": ent[0], "start": ent[1], "end": ent[2]}
1259-
if len(ent) == 3
1260-
else {"label": ent[0], "start": ent[1], "end": ent[2], "score": ent[3]}
1262+
(
1263+
{"label": ent[0], "start": ent[1], "end": ent[2]}
1264+
if len(ent) == 3
1265+
else {"label": ent[0], "start": ent[1], "end": ent[2], "score": ent[3]}
1266+
)
12611267
for ent in entities
12621268
]
12631269

@@ -1281,9 +1287,11 @@ def __entities_to_tuple__(
12811287
entities,
12821288
) -> List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]]:
12831289
return [
1284-
(ent["label"], ent["start"], ent["end"])
1285-
if len(ent) == 3
1286-
else (ent["label"], ent["start"], ent["end"], ent["score"] or 0.0)
1290+
(
1291+
(ent["label"], ent["start"], ent["end"])
1292+
if len(ent) == 3
1293+
else (ent["label"], ent["start"], ent["end"], ent["score"] or 0.0)
1294+
)
12871295
for ent in entities
12881296
]
12891297

argilla/src/argilla/client/feedback/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ class DatasetConfig(BaseModel):
4343
fields: List[AllowedFieldTypes]
4444
questions: List[Annotated[AllowedQuestionTypes, Field(..., discriminator="type")]]
4545
guidelines: Optional[str] = None
46-
metadata_properties: Optional[
47-
List[Annotated[AllowedMetadataPropertyTypes, Field(..., discriminator="type")]]
48-
] = None
46+
metadata_properties: Optional[List[Annotated[AllowedMetadataPropertyTypes, Field(..., discriminator="type")]]] = (
47+
None
48+
)
4949
allow_extra_metadata: bool = True
5050
vectors_settings: Optional[List[VectorSettings]] = None
5151

argilla/src/argilla/client/feedback/dataset/local/mixins.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -504,23 +504,27 @@ def for_text_classification(
504504
return cls(
505505
fields=[TextField(name="text", use_markdown=use_markdown)],
506506
questions=[
507-
LabelQuestion(
508-
name="label",
509-
labels=labels,
510-
description=description,
511-
)
512-
if not multi_label
513-
else MultiLabelQuestion(
514-
name="label",
515-
labels=labels,
516-
description=description,
507+
(
508+
LabelQuestion(
509+
name="label",
510+
labels=labels,
511+
description=description,
512+
)
513+
if not multi_label
514+
else MultiLabelQuestion(
515+
name="label",
516+
labels=labels,
517+
description=description,
518+
)
517519
)
518520
],
519-
guidelines=guidelines
520-
if guidelines is not None
521-
else default_guidelines
522-
if multi_label
523-
else default_guidelines.replace("one or more labels", "one label"),
521+
guidelines=(
522+
guidelines
523+
if guidelines is not None
524+
else (
525+
default_guidelines if multi_label else default_guidelines.replace("one or more labels", "one label")
526+
)
527+
),
524528
metadata_properties=metadata_properties,
525529
vectors_settings=vectors_settings,
526530
)
@@ -739,11 +743,15 @@ def for_supervised_fine_tuning(
739743
name="response", description="Write the response to the instruction.", use_markdown=use_markdown
740744
)
741745
],
742-
guidelines=guidelines
743-
if guidelines is not None
744-
else default_guidelines + " Take the context into account when writing the response."
745-
if context
746-
else default_guidelines,
746+
guidelines=(
747+
guidelines
748+
if guidelines is not None
749+
else (
750+
default_guidelines + " Take the context into account when writing the response."
751+
if context
752+
else default_guidelines
753+
)
754+
),
747755
metadata_properties=metadata_properties,
748756
vectors_settings=vectors_settings,
749757
)
@@ -977,23 +985,27 @@ def for_multi_modal_classification(
977985
return cls(
978986
fields=[TextField(name="content", use_markdown=True, required=True)],
979987
questions=[
980-
LabelQuestion(
981-
name="label",
982-
labels=labels,
983-
description=description,
984-
)
985-
if not multi_label
986-
else MultiLabelQuestion(
987-
name="label",
988-
labels=labels,
989-
description=description,
988+
(
989+
LabelQuestion(
990+
name="label",
991+
labels=labels,
992+
description=description,
993+
)
994+
if not multi_label
995+
else MultiLabelQuestion(
996+
name="label",
997+
labels=labels,
998+
description=description,
999+
)
9901000
)
9911001
],
992-
guidelines=guidelines
993-
if guidelines is not None
994-
else default_guidelines
995-
if multi_label
996-
else default_guidelines.replace("one or more labels", "one label"),
1002+
guidelines=(
1003+
guidelines
1004+
if guidelines is not None
1005+
else (
1006+
default_guidelines if multi_label else default_guidelines.replace("one or more labels", "one label")
1007+
)
1008+
),
9971009
metadata_properties=metadata_properties,
9981010
vectors_settings=vectors_settings,
9991011
)

argilla/src/argilla/client/feedback/integrations/textdescriptives.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,11 @@ def _add_text_descriptives_to_metadata(
308308
filtered_metrics = {key: value for key, value in metrics.items() if not pd.isna(value)}
309309
if metadata_prop_types is not None:
310310
filtered_metrics = {
311-
key: int(value)
312-
if metadata_prop_types.get(key) == "integer"
313-
else float(value)
314-
if metadata_prop_types.get(key) == "float"
315-
else value
311+
key: (
312+
int(value)
313+
if metadata_prop_types.get(key) == "integer"
314+
else float(value) if metadata_prop_types.get(key) == "float" else value
315+
)
316316
for key, value in filtered_metrics.items()
317317
}
318318
record.metadata.update(filtered_metrics)

argilla/src/argilla/client/feedback/schemas/metadata.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ def title_must_have_value(cls, v: Optional[str], values: Dict[str, Any]) -> str:
7171

7272
@property
7373
@abstractmethod
74-
def server_settings(self) -> Dict[str, Any]:
75-
...
74+
def server_settings(self) -> Dict[str, Any]: ...
7675

7776
def to_server_payload(self) -> Dict[str, Any]:
7877
return {
@@ -84,20 +83,17 @@ def to_server_payload(self) -> Dict[str, Any]:
8483

8584
@property
8685
@abstractmethod
87-
def _pydantic_field_with_validator(self) -> Tuple[Dict[str, Tuple[Any, ...]], Dict[str, Callable]]:
88-
...
86+
def _pydantic_field_with_validator(self) -> Tuple[Dict[str, Tuple[Any, ...]], Dict[str, Callable]]: ...
8987

9088
@abstractmethod
9189
def _validate_filter(self, metadata_filter: "MetadataFilters") -> None:
9290
pass
9391

9492
@abstractmethod
95-
def _check_allowed_value_type(self, value: Any) -> Any:
96-
...
93+
def _check_allowed_value_type(self, value: Any) -> Any: ...
9794

9895
@abstractmethod
99-
def _validator(self, value: Any) -> Any:
100-
...
96+
def _validator(self, value: Any) -> Any: ...
10197

10298

10399
def _validator_definition(schema: MetadataPropertySchema) -> Dict[str, Any]:
@@ -395,8 +391,7 @@ class Config:
395391

396392
@property
397393
@abstractmethod
398-
def query_string(self) -> str:
399-
...
394+
def query_string(self) -> str: ...
400395

401396

402397
class TermsMetadataFilter(MetadataFilterSchema):

argilla/src/argilla/client/feedback/schemas/remote/records.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ def __updated_record_data(self) -> None:
207207

208208
updated_record = self.from_api(
209209
payload=response.parsed,
210-
question_id_to_name={value: key for key, value in self.question_name_to_id.items()}
211-
if self.question_name_to_id
212-
else None,
210+
question_id_to_name=(
211+
{value: key for key, value in self.question_name_to_id.items()} if self.question_name_to_id else None
212+
),
213213
client=self.client,
214214
)
215215

@@ -306,15 +306,17 @@ def from_api(
306306
id=payload.id,
307307
client=client,
308308
fields=payload.fields,
309-
responses=[RemoteResponseSchema.from_api(response) for response in payload.responses]
310-
if payload.responses
311-
else [],
312-
suggestions=[
313-
RemoteSuggestionSchema.from_api(suggestion, question_id_to_name=question_id_to_name, client=client)
314-
for suggestion in payload.suggestions
315-
]
316-
if payload.suggestions
317-
else [],
309+
responses=(
310+
[RemoteResponseSchema.from_api(response) for response in payload.responses] if payload.responses else []
311+
),
312+
suggestions=(
313+
[
314+
RemoteSuggestionSchema.from_api(suggestion, question_id_to_name=question_id_to_name, client=client)
315+
for suggestion in payload.suggestions
316+
]
317+
if payload.suggestions
318+
else []
319+
),
318320
metadata=payload.metadata if payload.metadata else {},
319321
vectors=payload.vectors if payload.vectors else {},
320322
external_id=payload.external_id if payload.external_id else None,

argilla/src/argilla/client/sdk/text2text/models.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,11 @@ def from_client(cls, record: ClientText2TextRecord):
4848
if record.prediction is not None:
4949
prediction = Text2TextAnnotation(
5050
sentences=[
51-
Text2TextPrediction(text=pred[0], score=pred[1])
52-
if isinstance(pred, tuple)
53-
else Text2TextPrediction(text=pred)
51+
(
52+
Text2TextPrediction(text=pred[0], score=pred[1])
53+
if isinstance(pred, tuple)
54+
else Text2TextPrediction(text=pred)
55+
)
5456
for pred in record.prediction
5557
],
5658
agent=record.prediction_agent or MACHINE_NAME,
@@ -81,9 +83,9 @@ class Text2TextRecord(CreationText2TextRecord):
8183
def to_client(self) -> ClientText2TextRecord:
8284
return ClientText2TextRecord(
8385
text=self.text,
84-
prediction=[(sentence.text, sentence.score) for sentence in self.prediction.sentences]
85-
if self.prediction
86-
else None,
86+
prediction=(
87+
[(sentence.text, sentence.score) for sentence in self.prediction.sentences] if self.prediction else None
88+
),
8789
prediction_agent=self.prediction.agent if self.prediction else None,
8890
annotation=self.annotation.sentences[0].text if self.annotation else None,
8991
annotation_agent=self.annotation.agent if self.annotation else None,

argilla/src/argilla/client/sdk/text_classification/models.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,19 +100,21 @@ def to_client(self) -> ClientTextClassificationRecord:
100100
multi_label=self.multi_label,
101101
status=self.status,
102102
metadata=self.metadata or {},
103-
prediction=[(label.class_label, label.score) for label in self.prediction.labels]
104-
if self.prediction
105-
else None,
103+
prediction=(
104+
[(label.class_label, label.score) for label in self.prediction.labels] if self.prediction else None
105+
),
106106
prediction_agent=self.prediction.agent if self.prediction else None,
107107
annotation=annotations,
108108
annotation_agent=self.annotation.agent if self.annotation else None,
109109
vectors=self._to_client_vectors(self.vectors),
110-
explanation={
111-
key: [ClientTokenAttributions.parse_obj(attribution) for attribution in attributions]
112-
for key, attributions in self.explanation.items()
113-
}
114-
if self.explanation
115-
else None,
110+
explanation=(
111+
{
112+
key: [ClientTokenAttributions.parse_obj(attribution) for attribution in attributions]
113+
for key, attributions in self.explanation.items()
114+
}
115+
if self.explanation
116+
else None
117+
),
116118
metrics=self.metrics or None,
117119
search_keywords=self.search_keywords or None,
118120
)

argilla/src/argilla/client/sdk/token_classification/models.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def from_client(cls, record: ClientTokenClassificationRecord):
5858
if record.prediction is not None:
5959
prediction = TokenClassificationAnnotation(
6060
entities=[
61-
EntitySpan(label=ent[0], start=ent[1], end=ent[2])
62-
if len(ent) == 3
63-
else EntitySpan(label=ent[0], start=ent[1], end=ent[2], score=ent[3])
61+
(
62+
EntitySpan(label=ent[0], start=ent[1], end=ent[2])
63+
if len(ent) == 3
64+
else EntitySpan(label=ent[0], start=ent[1], end=ent[2], score=ent[3])
65+
)
6466
for ent in record.prediction
6567
],
6668
agent=record.prediction_agent or MACHINE_NAME,
@@ -94,13 +96,15 @@ def to_client(self) -> ClientTokenClassificationRecord:
9496
return ClientTokenClassificationRecord(
9597
text=self.text,
9698
tokens=self.tokens,
97-
prediction=[(ent.label, ent.start, ent.end, ent.score) for ent in self.prediction.entities]
98-
if self.prediction
99-
else None,
99+
prediction=(
100+
[(ent.label, ent.start, ent.end, ent.score) for ent in self.prediction.entities]
101+
if self.prediction
102+
else None
103+
),
100104
prediction_agent=self.prediction.agent if self.prediction else None,
101-
annotation=[(ent.label, ent.start, ent.end) for ent in self.annotation.entities]
102-
if self.annotation
103-
else None,
105+
annotation=(
106+
[(ent.label, ent.start, ent.end) for ent in self.annotation.entities] if self.annotation else None
107+
),
104108
annotation_agent=self.annotation.agent if self.annotation else None,
105109
vectors=self._to_client_vectors(self.vectors),
106110
id=self.id,

0 commit comments

Comments
 (0)