Skip to content

Commit 2d8c8ed

Browse files
authored
feat: Dreadnode SDK Integration (#203)
* wip * Finalizing migration to the sdk with tasks used everywhere, model annotations, task labels, and logging. * Migrate completion code. Small bug fix for http transform api_key * Separate caching mechanics * Add parallel mode for pipeline run_many and run_batch. Integrate with sdk scorers. * Port tools to base models * Add filters to pipeline score * Fix linting errors * Dependency and lint fixes * Fixing tests
1 parent a5488c0 commit 2d8c8ed

34 files changed

+2180
-692
lines changed

docs/api/chat.mdx

Lines changed: 400 additions & 103 deletions
Large diffs are not rendered by default.

docs/api/completion.mdx

Lines changed: 58 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -897,12 +897,21 @@ async def run(
897897
on_failed = on_failed or self.on_failed
898898
states = self._initialize_states(1)
899899

900-
with tracer.span(
901-
f"Completion with {self.generator.to_identifier()}",
902-
generator_id=self.generator.to_identifier(),
903-
params=self.params.to_dict() if self.params is not None else {},
904-
) as span:
905-
return (await self._run(span, states, on_failed))[0]
900+
with dn.task_span(
901+
f"pipeline - {self.generator.to_identifier(short=True)}",
902+
label=f"pipeline_{self.generator.to_identifier(short=True)}",
903+
attributes={"rigging.type": "completion_pipeline.run"},
904+
) as task:
905+
dn.log_inputs(
906+
text=self.text,
907+
params=self.params.to_dict() if self.params is not None else {},
908+
generator_id=self.generator.to_identifier(),
909+
)
910+
completions = await self._run(task, states, on_failed)
911+
completion = completions[0]
912+
dn.log_output("completion", completion)
913+
task.set_attribute("completions", completions)
914+
return completion
906915
```
907916

908917

@@ -978,13 +987,21 @@ async def run_batch(
978987
for state in states:
979988
next(state.processor)
980989

981-
with tracer.span(
982-
f"Completion batch with {self.generator.to_identifier()} ({len(states)})",
983-
count=len(states),
984-
generator_id=self.generator.to_identifier(),
985-
params=self.params.to_dict() if self.params is not None else {},
986-
) as span:
987-
return await self._run(span, states, on_failed, batch_mode=True)
990+
with dn.task_span(
991+
f"pipeline - {self.generator.to_identifier(short=True)} (batch x{len(states)})",
992+
label=f"pipeline_batch_{self.generator.to_identifier(short=True)}",
993+
attributes={"rigging.type": "completion_pipeline.run_batch"},
994+
) as task:
995+
dn.log_inputs(
996+
count=len(states),
997+
many=many,
998+
params=params,
999+
generator_id=self.generator.to_identifier(),
1000+
)
1001+
completions = await self._run(task, states, on_failed, batch_mode=True)
1002+
dn.log_output("completions", completions)
1003+
task.set_attribute("completions", completions)
1004+
return completions
9881005
```
9891006

9901007

@@ -1047,13 +1064,21 @@ async def run_many(
10471064
on_failed = on_failed or self.on_failed
10481065
states = self._initialize_states(count, params)
10491066

1050-
with tracer.span(
1051-
f"Completion with {self.generator.to_identifier()} (x{count})",
1052-
count=count,
1053-
generator_id=self.generator.to_identifier(),
1054-
params=self.params.to_dict() if self.params is not None else {},
1055-
) as span:
1056-
return await self._run(span, states, on_failed)
1067+
with dn.task_span(
1068+
f"pipeline - {self.generator.to_identifier(short=True)} (x{count})",
1069+
label=f"pipeline_many_{self.generator.to_identifier(short=True)}",
1070+
attributes={"rigging.type": "completion_pipeline.run_many"},
1071+
) as task:
1072+
dn.log_inputs(
1073+
count=count,
1074+
text=self.text,
1075+
params=self.params.to_dict() if self.params is not None else {},
1076+
generator_id=self.generator.to_identifier(),
1077+
)
1078+
completions = await self._run(task, states, on_failed)
1079+
dn.log_output("completions", completions)
1080+
task.set_attribute("completions", completions)
1081+
return completions
10571082
```
10581083

10591084

@@ -1133,9 +1158,20 @@ async def run_over(
11331158
sub.generator = generator
11341159
coros.append(sub.run(allow_failed=(on_failed != "raise")))
11351160

1136-
with tracer.span(f"Completion over {len(coros)} generators", count=len(coros)):
1161+
short_generators = [g.to_identifier(short=True) for g in _generators]
1162+
task_name = "iterate - " + ", ".join(short_generators)
1163+
1164+
with dn.task_span(
1165+
task_name,
1166+
label="iterate_over",
1167+
attributes={"rigging.type": "completion_pipeline.run_over"},
1168+
) as task:
1169+
dn.log_input("generators", [g.to_identifier() for g in _generators])
11371170
completions = await asyncio.gather(*coros)
1138-
return await self._post_run(completions, on_failed)
1171+
final_completions = await self._post_run(completions, on_failed)
1172+
dn.log_output("completions", final_completions)
1173+
task.set_attribute("completions", final_completions)
1174+
return final_completions
11391175
```
11401176

11411177

0 commit comments

Comments (0)