Skip to content

Commit 2e50041

Browse files
committed
fix data loader
1 parent d7b96ce commit 2e50041

File tree

4 files changed

+106
-101
lines changed

4 files changed

+106
-101
lines changed

tracestorm/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def main(
133133

134134
if datasets_config_file is None:
135135
datasets = []
136-
sort = None
136+
sort_strategy = None
137137
else:
138-
datasets, sort = load_datasets(datasets_config_file)
138+
datasets, sort_strategy = load_datasets(datasets_config_file)
139139

140140
_, result_analyzer = run_load_test(
141141
trace_generator=trace_generator,
@@ -144,7 +144,7 @@ def main(
144144
base_url=base_url,
145145
api_key=api_key,
146146
datasets=datasets,
147-
sort=sort,
147+
sort_strategy=sort_strategy,
148148
seed=seed,
149149
)
150150

tracestorm/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def run_load_test(
1818
base_url: str,
1919
api_key: str,
2020
datasets: List,
21-
sort: Optional[str] = None,
21+
sort_strategy: Optional[str] = None,
2222
seed: Optional[int] = None,
2323
) -> Tuple[List[Tuple], ResultAnalyzer]:
2424
"""
@@ -31,7 +31,7 @@ def run_load_test(
3131
base_url: Base URL for API calls
3232
api_key: API key for authentication
3333
datasets: List of datasets to generate prompts
34-
sort: Sorting strategy for prompts in datasets.
34+
sort_strategy: Sorting strategy for prompts in datasets.
3535
seed: Random seed for sorting.
3636
3737
Returns:
@@ -48,7 +48,7 @@ def run_load_test(
4848
model_name=model,
4949
nums=total_requests,
5050
datasets=datasets,
51-
sort=sort,
51+
sort_strategy=sort_strategy,
5252
seed=seed,
5353
)
5454
ipc_queue = multiprocessing.Queue()

tracestorm/data_loader.py

Lines changed: 96 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ def normalize_prompts(row) -> List[str]:
4444
"",
4545
)
4646
prompts.append(prompt)
47+
else: # we cannot handle this type
48+
continue
4749
elif isinstance(row, str): # if the row is already a prompt
4850
prompts.append(row)
4951
elif (
@@ -78,109 +80,112 @@ def load_datasets(
7880
Return:
7981
(List[Dataset], str): A list of Dataset objects and the sorting strategy.
8082
"""
83+
if datasets_config_file is None:
84+
logger.error("Customized data loading logic needs to be implemented!")
85+
return [], None
86+
8187
# Load datasets configuration file
82-
if datasets_config_file:
83-
try:
84-
with open(datasets_config_file, "r") as f:
85-
datasets_config = json.load(f)
86-
except FileNotFoundError:
87-
logger.error(
88-
f"Configuration file '{datasets_config_file}' not found"
89-
)
90-
return [], None
91-
except Exception as e:
92-
logger.error(f"Error reading '{datasets_config_file}': {e}")
93-
return [], None
94-
95-
# Strategy to sort the provided datasets
96-
sort_strategy = datasets_config.pop("sort", "random")
97-
98-
# List to store each Dataset
99-
datasets = []
88+
try:
89+
with open(datasets_config_file, "r") as f:
90+
datasets_config = json.load(f)
91+
except FileNotFoundError:
92+
logger.error(
93+
f"Configuration file '{datasets_config_file}' not found"
94+
)
95+
return [], None
96+
except Exception as e:
97+
logger.error(f"Error reading '{datasets_config_file}': {e}")
98+
return [], None
10099

101-
for name, config in datasets_config.items():
102-
file_name = config.get("file_name")
103-
prompt_field = config.get("prompt_field")
100+
# Strategy to sort the provided datasets
101+
sort_strategy = datasets_config.pop("sort_strategy", "random")
104102

105-
try:
106-
ratio = int(config.get("select_ratio", 1))
107-
except ValueError:
108-
logger.error(
109-
f"Invalid 'select_ratio' for dataset '{name}', using default 1"
110-
)
111-
ratio = 1
103+
# List to store each Dataset
104+
datasets = []
112105

113-
if not file_name or not prompt_field:
114-
logger.error(
115-
f"Missing required 'file_name' or 'prompt_field' for dataset '{name}'"
116-
)
117-
continue
106+
for name, config in datasets_config.items():
107+
file_name = config.get("file_name")
108+
prompt_field = config.get("prompt_field")
118109

119-
file_path = (
120-
os.path.abspath(file_name)
121-
if os.path.exists(file_name)
122-
else os.path.join(DEFAULT_DATASET_FOLDER, file_name)
110+
try:
111+
ratio = int(config.get("select_ratio", 1))
112+
except ValueError:
113+
logger.error(
114+
f"Invalid 'select_ratio' for dataset '{name}', using default 1"
123115
)
116+
ratio = 1
124117

125-
# Load dataset from local files
126-
if os.path.exists(file_path):
127-
prompts = []
128-
# CSV files
129-
if file_name.endswith(".csv"):
130-
data = pd.read_csv(file_path)
118+
if not file_name or not prompt_field:
119+
logger.error(
120+
f"Missing required 'file_name' or 'prompt_field' for dataset '{name}'"
121+
)
122+
continue
123+
124+
os.makedirs(DEFAULT_DATASET_FOLDER, exist_ok=True)
125+
126+
file_path = (
127+
os.path.abspath(file_name)
128+
if os.path.exists(file_name)
129+
else os.path.join(DEFAULT_DATASET_FOLDER, file_name)
130+
)
131131

132-
if prompt_field not in set(data.columns):
132+
# Load dataset from local files
133+
if os.path.exists(file_path):
134+
prompts = []
135+
# CSV files
136+
if file_name.endswith(".csv"):
137+
data = pd.read_csv(file_path)
138+
139+
if prompt_field not in set(data.columns):
140+
logger.error(
141+
f"Field '{prompt_field}' not found in '{file_path}'."
142+
)
143+
continue
144+
prompts = data[prompt_field].dropna().astype(str).tolist()
145+
# JSON files
146+
elif file_name.endswith(".json"):
147+
with open(file_path, "r") as f:
148+
data = json.load(f)
149+
150+
if isinstance(data, dict):
151+
prompts = data.get(prompt_field, [])
152+
if not isinstance(prompts, list):
133153
logger.error(
134-
f"Field '{prompt_field}' not found in '{file_path}'."
154+
f"Field '{prompt_field}' in '{file_path}' is not a list."
135155
)
136156
continue
137-
prompts = data[prompt_field].dropna().astype(str).tolist()
138-
# JSON files
139-
elif file_name.endswith(".json"):
140-
with open(file_path, "r") as f:
141-
data = json.load(f)
142-
143-
if isinstance(data, dict):
144-
prompts = data.get(prompt_field, [])
145-
if not isinstance(prompts, list):
146-
logger.error(
147-
f"Field '{prompt_field}' in '{file_path}' is not a list."
148-
)
149-
continue
150-
else:
151-
logger.error(f"Unsupported file format for '{file_name}'")
152-
continue
153157
else:
154-
try:
155-
if file_name.endswith(".csv"): # CSV format
156-
data = pd.read_csv(file_name)
157-
158-
if prompt_field not in set(data.columns):
159-
logger.error(
160-
f"Field '{prompt_field}' not found in '{file_name}'."
161-
)
162-
continue
163-
prompts = (
164-
data[prompt_field].dropna().astype(str).tolist()
158+
logger.error(f"Unsupported file format for '{file_name}'")
159+
continue
160+
else:
161+
try:
162+
if file_name.endswith(".csv"): # CSV format
163+
data = pd.read_csv(file_name)
164+
165+
if prompt_field not in set(data.columns):
166+
logger.error(
167+
f"Field '{prompt_field}' not found in '{file_name}'."
165168
)
166-
else: # use datasets to load
167-
data = load_dataset(file_name)["train"]
168-
prompts = []
169-
for row in data[prompt_field]:
170-
prompts.extend(normalize_prompts(row))
171-
except Exception as e:
172-
logger.error(f"Failed to load '{file_name}': {e}")
173-
174-
# Add the dataset information (file name, a list of prompts, select ratio among all datasets, total number of prompts)
175-
dataset_obj = Dataset(file_name, prompts, ratio, len(prompts))
176-
datasets.append(dataset_obj)
177-
178-
logger.info(
179-
f"loaded {file_name} with {len(prompts)} prompts, selection ratio = {ratio}"
180-
)
169+
continue
170+
prompts = (
171+
data[prompt_field].dropna().astype(str).tolist()
172+
)
173+
else: # use datasets to load
174+
data = load_dataset(file_name)["train"]
175+
prompts = []
176+
for row in data[prompt_field]:
177+
prompts.extend(normalize_prompts(row))
178+
except Exception as e:
179+
logger.error(f"Failed to load '{file_name}': {e}")
180+
181+
# Add the dataset information (file name, a list of prompts, select ratio among all datasets, total number of prompts)
182+
dataset_obj = Dataset(file_name, prompts, ratio, len(prompts))
183+
datasets.append(dataset_obj)
184+
185+
logger.info(
186+
f"loaded {file_name} with {len(prompts)} prompts, selection ratio = {ratio}"
187+
)
181188

182-
return datasets, sort_strategy
189+
return datasets, sort_strategy
183190

184-
else:
185-
logger.error("Customized data loading logic needs to be implemented!")
186-
return [], None
191+

tracestorm/request_generator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def generate_request(
1313
nums: int,
1414
messages: str = DEFAULT_MESSAGES,
1515
datasets: List[Dataset] = [],
16-
sort: str = "random",
16+
sort_strategy: str = "random",
1717
seed: int = None,
1818
) -> List[Dict[str, Any]]:
1919
# generate default requests without datasets
@@ -58,14 +58,14 @@ def generate_request(
5858
)
5959

6060
# 1. Randomly sort the requests
61-
if sort == "random":
61+
if sort_strategy == "random":
6262
if seed is not None:
6363
random.seed(seed)
6464
random.shuffle(dataset_samples)
65-
elif sort == "original": # 2. original order
65+
elif sort_strategy == "original": # 2. original order
6666
dataset_samples.sort(key=lambda x: x[0])
6767
else:
68-
raise ValueError(f"Unknown sort strategy: {sort}")
68+
raise ValueError(f"Unknown sorting strategy: {sort_strategy}")
6969

7070
# Extract the prompts from the list
7171
requests = [

0 commit comments

Comments
 (0)