Skip to content

Commit 970561a

Browse files
xuangu-fang and you-n-g authored
feat: prob-based trace scheduler (#1131)
* draft prob-based trace scheduler * refactor ProbabilisticScheduler * auto lint * keep random, Sota-based, length-based trace scheduler * lint * example * example * refactor * add inverse option for sota Scheduler * add trace_Scheduler in conf * lint * add scheduler_temperature --------- Co-authored-by: Young <[email protected]>
1 parent c36bfd8 commit 970561a

File tree

3 files changed

+249
-29
lines changed

3 files changed

+249
-29
lines changed

rdagent/app/data_science/conf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
2020

2121
planner: str = "rdagent.scenarios.data_science.proposal.exp_gen.planner.DSExpPlannerHandCraft"
2222
hypothesis_gen: str = "rdagent.scenarios.data_science.proposal.exp_gen.router.ParallelMultiTraceExpGen"
23+
trace_scheduler: str = "rdagent.scenarios.data_science.proposal.exp_gen.trace_scheduler.RoundRobinScheduler"
2324
"""Hypothesis generation class"""
2425

2526
summarizer: str = "rdagent.scenarios.data_science.dev.feedback.DSExperiment2Feedback"
@@ -91,6 +92,9 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
9192
max_trace_num: int = 3
9293
"""The maximum number of traces to grow before merging"""
9394

95+
scheduler_temperature: float = 1.0
96+
"""The temperature for the trace scheduler for softmax calculation, used in ProbabilisticScheduler"""
97+
9498
#### multi-trace:checkpoint selector
9599
selector_name: str = "rdagent.scenarios.data_science.proposal.exp_gen.select.expand.LatestCKPSelector"
96100
"""The name of the selector to use"""

rdagent/scenarios/data_science/proposal/exp_gen/router/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from rdagent.scenarios.data_science.proposal.exp_gen.proposal import DSProposalV2ExpGen
2323
from rdagent.scenarios.data_science.proposal.exp_gen.trace_scheduler import (
2424
RoundRobinScheduler,
25+
SOTABasedScheduler,
2526
TraceScheduler,
2627
)
2728

@@ -46,7 +47,11 @@ def __init__(self, *args, **kwargs):
4647
self.exp_gen = DataScienceRDLoop.default_exp_gen(self.scen)
4748
self.draft_exp_gen = DSDraftV2ExpGen(self.scen)
4849
self.merge_exp_gen = ExpGen2Hypothesis(self.scen)
49-
self.trace_scheduler: TraceScheduler = RoundRobinScheduler(DS_RD_SETTING.max_trace_num)
50+
# self.trace_scheduler: TraceScheduler = RoundRobinScheduler(DS_RD_SETTING.max_trace_num)
51+
self.trace_scheduler: TraceScheduler = import_class(DS_RD_SETTING.trace_scheduler)(
52+
DS_RD_SETTING.max_trace_num,
53+
DS_RD_SETTING.scheduler_temperature,
54+
)
5055
self.planner = import_class(DS_RD_SETTING.planner)(self.scen)
5156

5257
def gen(
Lines changed: 239 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import math
5+
import random
46
from abc import ABC, abstractmethod
57
from collections import defaultdict
68
from typing import TYPE_CHECKING
79

10+
from rdagent.log import rdagent_logger as logger
11+
812
if TYPE_CHECKING:
913
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace
1014

@@ -22,7 +26,7 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
2226
2327
For proposing selections, we have to follow the rules
2428
- Suggest selection: suggest a selection that is suitable for the current trace.
25-
- Suggested should be garenteed to be recorded at last!!!
29+
- Suggested should be garenteed to be recorded at last!!!!
2630
- If no suitable selection is found, the function should async wait!!!!
2731
2832
Args:
@@ -35,17 +39,8 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
3539
raise NotImplementedError
3640

3741

38-
class RoundRobinScheduler(TraceScheduler):
39-
"""
40-
A concurrency-safe scheduling strategy that cycles through active traces
41-
in a round-robin fashion.
42-
43-
NOTE: we don't need to use asyncio.Lock here as the kickoff_loop ensures the ExpGen is always sequential, instead of parallel.
44-
"""
45-
46-
def __init__(self, max_trace_num: int):
47-
self.max_trace_num = max_trace_num
48-
self._last_selected_leaf_id = -1
42+
class BaseScheduler(TraceScheduler):
43+
def __init__(self):
4944
self.rec_commit_idx = 0 # the node before rec_idx is already committed.
5045
self.uncommited_rec_status = defaultdict(int) # the uncommited record status
5146

@@ -56,25 +51,241 @@ async def next(self, trace: DSTrace) -> tuple[int, ...]:
5651
while True:
5752
# step 0: Commit the pending selections
5853
for i in range(self.rec_commit_idx, len(trace.dag_parent)):
59-
60-
if trace.dag_parent[i] == trace.NEW_ROOT:
54+
parent_of_i = trace.dag_parent[i]
55+
if parent_of_i == trace.NEW_ROOT:
6156
self.uncommited_rec_status[trace.NEW_ROOT] -= 1
6257
else:
63-
for p in trace.dag_parent[i]:
58+
for p in parent_of_i:
6459
self.uncommited_rec_status[p] -= 1
65-
6660
self.rec_commit_idx = len(trace.hist)
6761

68-
# step 1: select the parant trace to expand
69-
# Policy: if we have fewer traces than our target, start a new one.
70-
if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
71-
self.uncommited_rec_status[trace.NEW_ROOT] += 1
72-
return trace.NEW_ROOT
73-
74-
# Step2: suggest a selection to a not expanding leave
75-
leaves = trace.get_leaves()
76-
for leaf in leaves:
77-
if self.uncommited_rec_status[leaf] == 0:
78-
self.uncommited_rec_status[leaf] += 1
79-
return (leaf,)
62+
parents = self.select(trace)
63+
64+
if parents is not None:
65+
if parents == trace.NEW_ROOT:
66+
self.uncommited_rec_status[trace.NEW_ROOT] += 1
67+
else:
68+
for p in parents:
69+
self.uncommited_rec_status[p] += 1
70+
return parents
71+
8072
await asyncio.sleep(1)
73+
74+
@abstractmethod
75+
def select(self, trace: DSTrace) -> tuple[int, ...] | None:
76+
"""Selects the parent nodes for the new experiment, or None if no selection can be made."""
77+
raise NotImplementedError
78+
79+
80+
class RoundRobinScheduler(BaseScheduler):
    """
    Scheduler that first opens new traces until the target count is reached,
    then hands out the first leaf that has no uncommitted selection pending.

    NOTE: no asyncio.Lock is used here; the kickoff_loop is expected to run the
    ExpGen sequentially rather than in parallel (TODO confirm against the caller).
    """

    def __init__(self, max_trace_num: int, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            *args, **kwargs: Accepted and ignored, so that this class can be
                constructed with the same arguments as the other schedulers.
        """
        logger.info(f"RoundRobinScheduler: max_trace_num={max_trace_num}")
        self.max_trace_num = max_trace_num
        self._last_selected_leaf_id = -1
        super().__init__()

    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
        """
        Pick the parent for the next experiment.

        Returns:
            ``trace.NEW_ROOT`` while fewer than ``max_trace_num`` traces exist
            (counting pending new-root selections), otherwise a one-element
            tuple holding the first leaf with no pending selection, or ``None``
            when every leaf is already being expanded.
        """
        pending = self.uncommited_rec_status

        # Traces that already exist plus those that were suggested but not yet recorded.
        started = trace.sub_trace_count + pending[trace.NEW_ROOT]
        if started < self.max_trace_num:
            return trace.NEW_ROOT

        # Take the first leaf (in get_leaves() order) that nobody is expanding right now.
        idle_leaf = next((leaf for leaf in trace.get_leaves() if pending[leaf] == 0), None)
        return None if idle_leaf is None else (idle_leaf,)
110+
111+
112+
# ======================================================================================
113+
# Probabilistic Scheduler and its potential functions
114+
# ======================================================================================
115+
116+
117+
class ProbabilisticScheduler(BaseScheduler):
    """
    A scheduling strategy that samples the next trace to expand from a
    probability distribution derived from a potential function.

    Subclasses customise the behaviour by overriding ``calculate_potential``.
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Temperature parameter for softmax calculation. Higher values make selection more uniform.
            *args, **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``max_trace_num`` or ``temperature`` is not positive.
        """
        if max_trace_num <= 0:
            raise ValueError("max_trace_num must be positive.")
        if temperature <= 0:
            raise ValueError("temperature must be positive.")

        self.max_trace_num = max_trace_num
        self.temperature = temperature
        super().__init__()

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Calculate potential score for a given leaf node.
        This is the base implementation that provides uniform distribution.

        Args:
            trace: The DSTrace object containing the full experiment history.
            leaf_id: The index of the leaf node to evaluate.

        Returns:
            float: A potential score. Higher means more likely to be selected.
        """
        return 1.0  # Uniform distribution by default

    def _softmax_probabilities(self, potentials: list[float]) -> list[float]:
        """
        Convert potential scores to probabilities using a temperature-scaled softmax.

        Args:
            potentials: List of potential scores.

        Returns:
            List of probabilities that sum to 1 (an empty list for empty input).
        """
        if not potentials:
            return []

        # Apply temperature scaling
        scaled_potentials = [p / self.temperature for p in potentials]

        # Subtract the maximum before exponentiating so every exponent is <= 0
        # (no overflow). The largest entry then contributes exp(0) == 1, so the
        # normaliser is never zero and no zero-sum fallback is needed.
        max_potential = max(scaled_potentials)
        exp_potentials = [math.exp(p - max_potential) for p in scaled_potentials]
        sum_exp = sum(exp_potentials)

        return [exp_p / sum_exp for exp_p in exp_potentials]

    def select(self, trace: DSTrace) -> tuple[int, ...] | None:
        """
        Selects the next leaf node based on probabilistic sampling.

        Returns:
            ``trace.NEW_ROOT`` while fewer than ``max_trace_num`` traces exist
            (counting pending new-root selections); otherwise a one-element
            tuple with the sampled leaf, or ``None`` when no leaf is available.

        Raises:
            ValueError: If ``calculate_potential`` returns a negative value.
        """
        # Step 1: If we have fewer traces than our target, start a new one.
        # This policy prioritizes reaching the desired number of traces.
        if trace.sub_trace_count + self.uncommited_rec_status[trace.NEW_ROOT] < self.max_trace_num:
            return trace.NEW_ROOT

        # Step 2: Probabilistically select a leaf to expand.
        # Only leaves without a pending (uncommitted) selection are candidates.
        leaves = trace.get_leaves()
        available_leaves = [leaf for leaf in leaves if self.uncommited_rec_status[leaf] == 0]

        if not available_leaves:
            return None

        # Calculate potential for each available leaf
        potentials = [self.calculate_potential(trace, leaf) for leaf in available_leaves]

        if any(p < 0 for p in potentials):
            raise ValueError("Potential function returned a negative value.")

        # Convert potentials to probabilities using softmax
        probabilities = self._softmax_probabilities(potentials)

        # Select a leaf based on probabilities
        selected_leaf = random.choices(available_leaves, weights=probabilities, k=1)[0]

        return (selected_leaf,)
208+
209+
210+
class TraceLengthScheduler(ProbabilisticScheduler):
    """
    Scheduler whose potential is driven by trace length.

    -- default: a leaf with more entries on its path gets a higher potential.
    -- if inverse=True, a leaf with fewer entries on its path gets a higher potential.
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, inverse: bool = False, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Temperature parameter for softmax calculation.
            inverse: If True, shorter traces get higher potential.
            *args, **kwargs: Accepted and ignored.
        """
        logger.info(
            f"TraceLengthScheduler: max_trace_num={max_trace_num}, temperature={temperature}, inverse={inverse}"
        )
        super().__init__(max_trace_num, temperature)
        self.inverse = inverse

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Score a leaf by the number of entries returned by ``trace.get_parents``.

        Returns:
            1.0 for an empty path; otherwise the path length, or its
            reciprocal when ``inverse`` is set.
        """
        # presumably the chain of nodes leading to this leaf -- confirm against DSTrace.get_parents
        depth = len(trace.get_parents(leaf_id))

        if depth == 0:
            return 1.0
        if self.inverse:
            return 1.0 / depth
        return float(depth)
243+
244+
245+
class SOTABasedScheduler(ProbabilisticScheduler):
    """
    Scheduler whose potential counts the accepted experiments
    (``feedback.decision`` truthy) found on the path of each leaf.
    """

    def __init__(self, max_trace_num: int, temperature: float = 1.0, inverse: bool = False, *args, **kwargs):
        """
        Args:
            max_trace_num: The target number of parallel traces.
            temperature: Temperature parameter for softmax calculation.
            inverse: If True, fewer SOTA results get higher potential.
            *args, **kwargs: Accepted and ignored.
        """
        logger.info(f"SOTABasedScheduler: max_trace_num={max_trace_num}, temperature={temperature}, inverse={inverse}")
        super().__init__(max_trace_num, temperature)
        self.inverse = inverse

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Score a leaf by how many nodes on its path have a truthy ``feedback.decision``.

        Returns:
            The count as a float, or ``1 / (count + 1)`` when ``inverse`` is set.
        """
        sota_count = 0

        # presumably the chain of nodes leading to this leaf -- confirm against DSTrace.get_parents
        for node_id in trace.get_parents(leaf_id):
            # Ids outside the recorded history are ignored.
            if node_id >= len(trace.hist):
                continue
            _, feedback = trace.hist[node_id]
            if feedback.decision:
                sota_count += 1

        if not self.inverse:
            return float(sota_count)

        # The +1 avoids division by zero and gives paths with no SOTA the highest potential.
        return 1.0 / (sota_count + 1)
280+
281+
282+
class RandomScheduler(ProbabilisticScheduler):
    """
    A scheduler that assigns every leaf a freshly drawn random potential.
    """

    def calculate_potential(self, trace: DSTrace, leaf_id: int) -> float:
        """
        Return a random potential in [0.0, 1.0); ``trace`` and ``leaf_id`` are not used.

        The potential is redrawn on every call; the inherited ``select``
        passes it through softmax before sampling.
        """
        potential = random.random()
        return potential

0 commit comments

Comments
 (0)