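"""Ingest LongBench samples into a memory framework.

Each sample's "context" field is stored as a single assistant message under a
per-sample user/conversation ID. Supported backends: mem0 (optionally
graph-enabled), memos-api (local or online), memobase, memu, and supermemory.
"""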
import argparse
import json
import os
import sys

from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone

from dotenv import load_dotenv
from tqdm import tqdm


# Make the repository root and the shared evaluation utilities importable
# regardless of the current working directory.
ROOT_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
EVAL_SCRIPTS_DIR = os.path.join(ROOT_DIR, "evaluation", "scripts")

sys.path.insert(0, ROOT_DIR)
sys.path.insert(0, EVAL_SCRIPTS_DIR)


# All LongBench datasets, grouped by task type.
LONGBENCH_DATASETS = [
    # Single-document QA
    "narrativeqa",
    "qasper",
    "multifieldqa_en",
    "multifieldqa_zh",
    # Multi-document QA
    "hotpotqa",
    "2wikimqa",
    "musique",
    "dureader",
    # Summarization
    "gov_report",
    "qmsum",
    "multi_news",
    "vcsum",
    # Few-shot learning
    "trec",
    "triviaqa",
    "samsum",
    "lsht",
    # Synthetic tasks
    "passage_count",
    "passage_retrieval_en",
    "passage_retrieval_zh",
    # Code completion
    "lcc",
    "repobench-p",
]


def ingest_sample(client, sample, dataset_name, sample_idx, frame, version):
    """Ingest a single LongBench sample as memories. Returns True on success."""
    user_id = f"longbench_{dataset_name}_{sample_idx}_{version}"
    conv_id = f"longbench_{dataset_name}_{sample_idx}_{version}"

    # Ingest the full context as one assistant message; the sample's "input"
    # (the question) is not needed at ingestion time, and chunking of long
    # contexts is left to the memory framework itself.
    context = sample.get("context", "")
    messages = [
        {
            "role": "assistant",
            "content": context,
            "chat_time": datetime.now(timezone.utc).isoformat(),
        }
    ]

    # Each framework's client takes slightly different keyword arguments;
    # success and error reporting are identical for all of them.
    try:
        if "memos-api" in frame:
            client.add(messages=messages, user_id=user_id, conv_id=conv_id, batch_size=1)
        elif "mem0" in frame:
            timestamp = int(datetime.now(timezone.utc).timestamp())
            client.add(messages=messages, user_id=user_id, timestamp=timestamp, batch_size=1)
        elif frame == "memobase":
            # memobase additionally expects a created_at timestamp per message.
            for m in messages:
                m["created_at"] = m["chat_time"]
            client.add(messages=messages, user_id=user_id, batch_size=1)
        elif frame == "memu":
            client.add(messages=messages, user_id=user_id, iso_date=messages[0]["chat_time"])
        elif frame == "supermemory":
            client.add(messages=messages, user_id=user_id)
        else:
            return False
    except Exception as e:
        print(f"❌ [{frame}] Error ingesting sample {sample_idx} from {dataset_name}: {e}")
        return False

    print(f"✅ [{frame}] Ingested sample {sample_idx} from {dataset_name}")
    return True


def load_dataset_from_local(dataset_name, use_e=False):
    """Load a LongBench dataset from a local JSONL file."""
    data_dir = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))),
        "data",
        "long_bench_v2",
    )

    # LongBench-E files carry an "_e" suffix.
    filename = f"{dataset_name}_e.jsonl" if use_e else f"{dataset_name}.jsonl"
    filepath = os.path.join(data_dir, filename)

    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Dataset file not found: {filepath}")

    samples = []
    with open(filepath, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line))

    return samples

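# Expected input: one JSON object per line. The field names below follow the
# published LongBench release and are an assumption, not verified against the
# local files; only "context" is consumed at ingestion time:
#   {"input": "question", "context": "long document ...", "answers": ["..."],
#    "length": 9241, "dataset": "narrativeqa", "language": "en", "_id": "..."}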

def ingest_dataset(dataset_name, frame, version, num_workers=10, max_samples=None, use_e=False):
    """Ingest one LongBench dataset. Returns the success count, or None on failure."""
    print(f"\n{'=' * 80}")
    print(f"🔄 [INGESTING DATASET: {dataset_name.upper()}]".center(80))
    print(f"{'=' * 80}\n")

    # Load the dataset from local files
    try:
        dataset = load_dataset_from_local(dataset_name, use_e)
        print(f"Loaded {len(dataset)} samples from {dataset_name}")
    except Exception as e:
        print(f"❌ Error loading dataset {dataset_name}: {e}")
        return None

    # Limit samples if specified
    if max_samples:
        dataset = dataset[:max_samples]
        print(f"Limited to {len(dataset)} samples")

    # Initialize the client for the selected framework. Imports are deferred
    # so only the dependencies of the chosen framework need to be installed.
    if frame in ("mem0", "mem0_graph"):
        from utils.client import Mem0Client

        client = Mem0Client(enable_graph="graph" in frame)
    elif frame == "memos-api":
        from utils.client import MemosApiClient

        client = MemosApiClient()
    elif frame == "memos-api-online":
        from utils.client import MemosApiOnlineClient

        client = MemosApiOnlineClient()
    elif frame == "memobase":
        from utils.client import MemobaseClient

        client = MemobaseClient()
    elif frame == "memu":
        from utils.client import MemuClient

        client = MemuClient()
    elif frame == "supermemory":
        from utils.client import SupermemoryClient

        client = SupermemoryClient()
    else:
        print(f"❌ Unsupported frame: {frame}")
        return None

    # Ingest samples in parallel
    success_count = 0
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [
            executor.submit(ingest_sample, client, sample, dataset_name, idx, frame, version)
            for idx, sample in enumerate(dataset)
        ]

        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc=f"Ingesting {dataset_name}",
        ):
            try:
                if future.result():
                    success_count += 1
            except Exception as e:
                print(f"Error processing sample: {e}")

    print(f"\n✅ Completed ingesting {dataset_name}: {success_count}/{len(dataset)} samples")
    return success_count


def main(frame, version="default", num_workers=10, datasets=None, max_samples=None, use_e=False):
    """Main ingestion function."""
    load_dotenv()

    print("\n" + "=" * 80)
    print(f"🚀 LONGBENCH INGESTION - {frame.upper()} v{version}".center(80))
    print("=" * 80 + "\n")

    # Determine which datasets to process (default: all of them)
    dataset_list = [d.strip() for d in datasets.split(",")] if datasets else LONGBENCH_DATASETS

    # Drop any names that are not LongBench datasets
    valid_datasets = [d for d in dataset_list if d in LONGBENCH_DATASETS]
    if not valid_datasets:
        print("❌ No valid datasets specified")
        return

    print(f"Processing {len(valid_datasets)} datasets: {valid_datasets}\n")

    # Ingest each dataset and tally successfully ingested samples
    total_success = 0
    for dataset_name in valid_datasets:
        success = ingest_dataset(dataset_name, frame, version, num_workers, max_samples, use_e)
        if success is not None:
            total_success += success

    print(f"\n{'=' * 80}")
    print(f"✅ INGESTION COMPLETE: {total_success} samples ingested".center(80))
    print(f"{'=' * 80}\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lib",
        type=str,
        choices=[
            "mem0",
            "mem0_graph",
            "memos-api",
            "memos-api-online",
            "memobase",
            "memu",
            "supermemory",
        ],
        default="memos-api",
        help="Memory framework to ingest into",
    )
    parser.add_argument(
        "--version",
        type=str,
        default="default",
        help="Version identifier for saving results",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=10,
        help="Number of parallel workers",
    )
    parser.add_argument(
        "--datasets",
        type=str,
        default=None,
        help="Comma-separated list of datasets to process (default: all)",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="Maximum number of samples per dataset (default: all)",
    )
    parser.add_argument(
        "--e",
        action="store_true",
        help="Use LongBench-E variant (uniform length distribution)",
    )
    args = parser.parse_args()

    main(
        args.lib,
        args.version,
        args.workers,
        args.datasets,
        args.max_samples,
        args.e,
    )
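
# Example invocation (the script filename here is hypothetical):
#   python ingest_longbench.py --lib memos-api --datasets narrativeqa,qasper \
#       --workers 8 --max_samples 50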