Skip to content

Commit 4da21a5

Browse files
committed
adjust based on comments
1 parent fea1666 commit 4da21a5

File tree

3 files changed

+4
-18
lines changed

3 files changed

+4
-18
lines changed

trinity/common/workflows/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@
4646
"async_simple_mm_workflow": "trinity.common.workflows.simple_mm_workflow.AsyncSimpleMMWorkflow",
4747
# on-policy distillation workflows
4848
"on_policy_distill_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillWorkflow",
49-
"async_on_policy_distill_workflow": "trinity.common.workflows.on_policy_distill_workflow.AsyncOnPolicyDistillWorkflow",
5049
"on_policy_distill_math_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillMathWorkflow",
51-
"async_on_policy_distill_math_workflow": "trinity.common.workflows.on_policy_distill_workflow.AsyncOnPolicyDistillMathWorkflow",
5250
},
5351
)
5452

trinity/common/workflows/on_policy_distill_workflow.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,6 @@ async def run_async(self) -> List[Experience]:
144144
return responses
145145

146146

147-
class AsyncOnPolicyDistillWorkflow(OnPolicyDistillWorkflow):
148-
"""Alias for OnPolicyDistillWorkflow (already async)."""
149-
150-
pass
151-
152-
153147
class OnPolicyDistillMathWorkflow(OnPolicyDistillWorkflow):
154148
"""On-policy distillation workflow with Qwen2.5-Math style format.
155149
@@ -187,9 +181,3 @@ def compute_reward(self, response: Experience) -> float:
187181
response.metrics["accuracy"] = accuracy
188182
return float(accuracy)
189183
return 0.0
190-
191-
192-
class AsyncOnPolicyDistillMathWorkflow(OnPolicyDistillMathWorkflow):
193-
"""Alias for OnPolicyDistillMathWorkflow (already async)."""
194-
195-
pass

trinity/common/workflows/workflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
from __future__ import annotations
55

66
from dataclasses import asdict, dataclass, field
7-
from typing import List, Optional, Type, Union
8-
9-
import openai
7+
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union
108

119
from trinity.common.config import FormatConfig, GenerationConfig
1210
from trinity.common.experience import Experience
13-
from trinity.common.models.model import ModelWrapper
1411
from trinity.common.rewards.reward_fn import RewardFn
1512
from trinity.utils.log import get_logger
1613

14+
if TYPE_CHECKING:
15+
import openai
16+
from trinity.common.models.model import ModelWrapper
1717

1818
@dataclass
1919
class Task(dict):

0 commit comments

Comments
 (0)