Skip to content

Commit 2bc0605

Browse files
authored
Merge pull request #224 from Mundi-Xu/datasets-optimize
refactor: standardize CSV loading from ./datasets and improve robustness
2 parents 444f908 + 335787d commit 2bc0605

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

agentic_security/probe_data/data.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,13 @@ def load_jailbreak_v28k() -> ProbeDataset:
248248
@cache_to_disk()
249249
def load_local_csv() -> ProbeDataset:
250250
"""Load prompts from local CSV files."""
251-
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
251+
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
252252
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
253253

254254
prompts = []
255255
for file in csv_files:
256256
try:
257-
df = pd.read_csv(file)
257+
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
258258
if "prompt" in df.columns:
259259
prompts.extend(df["prompt"].tolist())
260260
else:
@@ -270,7 +270,7 @@ def load_csv(file: str) -> ProbeDataset:
270270
"""Load prompts from local CSV files."""
271271
prompts = []
272272
try:
273-
df = pd.read_csv(file)
273+
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
274274
prompts = df["prompt"].tolist()
275275
if "prompt" in df.columns:
276276
prompts.extend(df["prompt"].tolist())
@@ -284,14 +284,14 @@ def load_csv(file: str) -> ProbeDataset:
284284
@cache_to_disk(1)
285285
def load_local_csv_files() -> list[ProbeDataset]:
286286
"""Load prompts from local CSV files and return a list of ProbeDataset objects."""
287-
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
287+
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
288288
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")
289289

290290
datasets = []
291291

292292
for file in csv_files:
293293
try:
294-
df = pd.read_csv(file)
294+
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
295295
if "prompt" in df.columns:
296296
prompts = df["prompt"].tolist()
297297
datasets.append(create_probe_dataset(file, prompts, {"src": file}))
@@ -336,12 +336,18 @@ def transform(self) -> Iterator[ProbeDataset]:
336336
)
337337
continue
338338

339-
size = len(dataset.prompts)
339+
# Convert async generator (if any) to a list
340+
prompts = list(dataset.prompts) if hasattr(dataset.prompts, '__iter__') else []
341+
size = len(prompts)
342+
343+
if size == 0:
344+
continue
345+
340346
sample_size = max(1, int(size * 0.25)) # At least 1 sample
341347

342348
for name, fn in self.TRANSFORMATIONS.items():
343349
logger.info(f"Applying {name} to {dataset.dataset_name}")
344-
sampled = random.sample(dataset.prompts, min(sample_size, size))
350+
sampled = random.sample(prompts, min(sample_size, size))
345351
try:
346352
transformed = [
347353
f"Follow the instructions in {name}: {fn(str(p))}"

0 commit comments

Comments (0)