[Eagle Offline] multinode support for hidden states dumper #422
**`compute_hidden_states_trtllm.py`**

```diff
@@ -20,10 +20,10 @@
 os.environ["TLLM_LOG_LEVEL"] = "error"
 import argparse
 import asyncio
-import json
 from pathlib import Path
 
 import torch
+from datasets import load_dataset
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
 from tqdm import tqdm as tqdm
```
```diff
@@ -59,12 +59,10 @@ def parse_args() -> argparse.Namespace:
 
     ## I/O Parameters ##
     parser.add_argument(
-        "--input-file",
+        "--input-data",
         type=Path,
         required=True,
-        help="""Path to the input `jsonl` file containing conversations.
-        Each entry must have a unique `conversation_id` field and a `conversations` field
-        containing a list of messages.""",
+        help="""Path to the `jsonl` file or directory containing `jsonl` files.""",
     )
     parser.add_argument(
         "--output-dir",
```
```diff
@@ -84,7 +82,13 @@ def parse_args() -> argparse.Namespace:
         "--dp-rank",
         type=int,
         default=0,
-        help="""Data parallel rank.""",
+        help="""Data parallel rank. TASK_ID on SLURM.""",
+    )
+    parser.add_argument(
+        "--dp-world-size",
+        type=int,
+        default=1,
+        help="""Data parallel world size. Number of tasks on SLURM.""",
     )
     parser.add_argument(
         "--use-cuda-graph",
```
```diff
@@ -101,21 +105,21 @@ def parse_args() -> argparse.Namespace:
     # moe_ep * moe_tp * moe_cp should be equal to tp
     # REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html
     parser.add_argument(
-        "--moe_ep",
+        "--moe-ep",
         type=int,
-        default=1,
+        default=None,
         help="""moe_expert_parallel_size for TRTLLM.""",
     )
     parser.add_argument(
-        "--moe_tp",
+        "--moe-tp",
         type=int,
-        default=1,
+        default=None,
         help="""moe_tensor_parallel_size for TRTLLM.""",
     )
     parser.add_argument(
-        "--moe_cp",
+        "--moe-cp",
         type=int,
-        default=1,
+        default=None,
         help="""moe_cluster_parallel_size for TRTLLM.""",
     )
```
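The comment above these flags states the constraint `moe_ep * moe_tp * moe_cp == tp`. A tiny check of that invariant (the values are made up for illustration, not script defaults):

```python
# Example decomposition for tp=8: 4-way expert parallel x 2-way tensor
# parallel x 1-way cluster parallel. Any factorization of tp satisfies it.
tp = 8
moe_ep, moe_tp, moe_cp = 4, 2, 1
assert moe_ep * moe_tp * moe_cp == tp, (
    f"moe_ep*moe_tp*moe_cp must equal tp: {moe_ep}*{moe_tp}*{moe_cp} != {tp}"
)
```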
```diff
@@ -124,28 +128,43 @@
 
 def main(args: argparse.Namespace) -> None:
-    # Load conversations
-    all_conversations = []
-    with args.input_file.open("r", encoding="utf-8") as f:
-        all_conversations.extend([json.loads(line) for line in f if line.strip()])
-    print("Loaded", len(all_conversations), "conversations from", args.input_file)
-
-    # Remove conversations whose output file already exists
-    filtered_conversations = []
-    for entry in all_conversations:
-        conversation_id = entry.get("conversation_id", None)
-        if conversation_id is None:
-            filtered_conversations.append(entry)
-            continue
-        output_file = args.output_dir / f"{conversation_id}.pt"
-        if output_file.exists():
-            continue
-        filtered_conversations.append(entry)
+    if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
+        dataset = load_dataset("json", data_files=str(args.input_data), split="train")
+    elif args.input_data.is_dir():
+        dataset = load_dataset(
+            "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
+        )
+    else:
+        raise ValueError(
+            f"input_data must be a .jsonl file or directory containing .jsonl files, got: {args.input_data}"
+        )
+    print(f"Loaded {len(dataset)} conversations from {args.input_data}")
+
+    # Shard data
+    if args.dp_world_size > 1:
+        dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank)
+        print(
+            f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}"
+        )
+
+    # Remove already dumped conversations
+    def keep_conversation(entry):
+        conversation_id = entry.get("conversation_id", entry.get("uuid", None))
+        assert conversation_id is not None, "conversation_id is required"
+        output_file = args.output_dir / f"{conversation_id}.pt"
+        return not output_file.exists()
 
+    original_num = len(dataset)
+    dataset = dataset.filter(keep_conversation)
     print(
         "Removed",
-        len(all_conversations) - len(filtered_conversations),
+        original_num - len(dataset),
         "conversations due to existing output files",
     )
-    all_conversations = filtered_conversations
 
     # For debugging
+    if args.debug_max_num_conversations is not None:
+        dataset = dataset.select(range(args.debug_max_num_conversations))
 
     # Get model config and tokenizer
     model_config = AutoConfig.from_pretrained(args.model)
```
**Review comment (lines +166 to 168):** Guard the debug cap against short datasets. If `args.debug_max_num_conversations` exceeds `len(dataset)`, `dataset.select(range(...))` raises an out-of-range error, so cap it first. Committable suggestion:

```diff
-        dataset = dataset.select(range(args.debug_max_num_conversations))
+        limit = min(args.debug_max_num_conversations, len(dataset))
+        dataset = dataset.select(range(limit))
```
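For intuition about what the new `dataset.shard(num_shards=..., index=...)` call gives each DP rank, a small standalone sketch (the PR call omits `contiguous`, whose default has differed across `datasets` releases, so it is passed explicitly here):

```python
from datasets import Dataset

ds = Dataset.from_dict({"conversation_id": [f"{i:04d}" for i in range(10)]})
# Strided split: rank 1 of 4 keeps rows 1, 5, 9. The ranks' shards are
# disjoint and together cover the full dataset.
shard = ds.shard(num_shards=4, index=1, contiguous=False)
print(shard["conversation_id"])  # ['0001', '0005', '0009']
```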
```diff
@@ -187,10 +206,7 @@ def main(args: argparse.Namespace) -> None:
     num_skipped_too_long = 0
     num_invalid = 0
     num_success = 0
-    num_total_conversations = min(
-        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
-    )
-    pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")
+    pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")
 
     def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
         """Post-process the TRTLLM dumped file to same format as HF dumped:
```
```diff
@@ -234,15 +250,17 @@ async def submit_generates():
         nonlocal num_skipped_too_long
         nonlocal num_invalid
         tasks = []
-        for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
-            conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
+        for idx, entry in enumerate(dataset):
+            conversation_id = entry.get("conversation_id", entry.get("uuid"))
 
             conversations = entry["conversations"]
             if not conversations or not isinstance(conversations, list):
                 num_invalid += 1
                 continue
 
-            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
+            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
+                :256
+            ]
             num_input_tokens = (
                 input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
             )
```
**Review comment (lines +261 to +263) — 🛠️ Refactor suggestion | 🟠 Major:** Clarify the magic number 256 for tokenization truncation. The hardcoded limit of 256 tokens appears arbitrary and may not align with the model's actual context window or the user's intent. Consider one of the following approaches: expose the limit as a CLI argument,

```diff
+parser.add_argument(
+    "--max-input-tokens",
+    type=int,
+    default=256,
+    help="Maximum number of tokens to use from conversation input for context."
+)
```

or name and document it as a constant:

```diff
+# Limit input tokens to reduce memory usage during hidden state collection
+MAX_INPUT_TOKENS = 256
+
 input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-    :256
+    :MAX_INPUT_TOKENS
 ]
```
```diff
@@ -262,12 +280,10 @@ async def submit_generates():
     if num_invalid > 0:
         print(f"Skipped {num_invalid} invalid conversations without proper fields.")
 
-    if num_success == num_total_conversations:
+    if num_success == len(dataset):
         print(f"Successfully processed all {num_success} conversations.")
     else:
-        print(
-            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
-        )
+        print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")
 
 
 if __name__ == "__main__":
```
**New file: SLURM launch script** (`@@ -0,0 +1,48 @@`)

```bash
#!/bin/bash

# CHANGE THE FOLLOWING TO YOUR ACCOUNT AND CHANGE THE JOB NAME TO COMPLY WITH THE
# USAGE. SWITCH TO `-p luna -t 04:00:00` IF YOU HAVE BEEN GRANTED CAPACITY FROM
# THE BIWEEKLY CAPACITY MEETING. IF YOU DON'T KNOW WHO IS THE PIC OF YOUR CSRG PPP
# MANAGEMENT, GO WITH `-p backfill -t 00:25:00`.

#SBATCH -A coreai_dlalgo_modelopt
#SBATCH --job-name=coreai_dlalgo_modelopt-mcore.modelopt
#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
```
**Review comment (on the `#SBATCH --nodes` line):** Align allocation with intended topology (1 proc using 4 GPUs). You set `ntasks-per-node=4` but run a single command that does TP=4. Allocate 1 task and give it 4 GPUs to avoid idle tasks and binding ambiguity.

```diff
-#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
+#SBATCH --nodes=1 --ntasks-per-node=1 --gpus-per-task=4
```

Optionally, also make the step explicit:

```diff
-timeout 235m srun -l \
+timeout 235m srun -l -n 1 \
```
```bash
#SBATCH -p batch
#SBATCH -t 04:00:00

echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"
echo "SLURM_ARRAY_TASK_COUNT: $SLURM_ARRAY_TASK_COUNT"

CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"

INPUT_DIR="<Can be directory containing the .jsonl files, or path to single .jsonl file>"
DUMP_DIR="<Directory for output hidden states>"
MODELOPT_DIR="<Path to Modelopt repo>"
TEACHER_MODEL="<Path to teacher model>"

if [ ! -d "$DUMP_DIR" ]; then
    mkdir -p "$DUMP_DIR"
fi

MOUNTS=$INPUT_DIR:/input,$DUMP_DIR:/output,$MODELOPT_DIR:/modelopt,$TEACHER_MODEL:/model
rm -rf $DUMP_DIR/*

#By default: TP inside node, and DP across slurm array
#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
```
**Review comment (lines +31 to +33):** Provide safe defaults when not using an array job. If `SLURM_ARRAY_TASK_ID`/`SLURM_ARRAY_TASK_COUNT` are unset, the dp args become empty and the Python launcher may fail. Default to a single-rank DP.

```diff
-#By default: TP inside node, and DP across slurm array
-#EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
-PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
+# By default: TP inside node, and DP across slurm array
+# EP optionally available by setting --moe-ep-size and --moe-tp-size. See compute_hidden_states_trtllm.py.
+DP_RANK="${SLURM_ARRAY_TASK_ID:-0}"
+DP_WORLD="${SLURM_ARRAY_TASK_COUNT:-1}"
+PARALLEL_ARGS="--tp 4 --dp-rank ${DP_RANK} --dp-world-size ${DP_WORLD}"
```
```bash
RUN_DUMPER="export TLLM_LOG_LEVEL="error";
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
  --model /model \
  --input-data /input/ \
  --output-dir /output \
  $PARALLEL_ARGS \
  "
```

**Review comment (line 34):** Fix shell quoting issue. The line has incorrect nested quoting: the inner quotes around `error` end the string early (Shellcheck 0.11.0, SC2140: word is of the form `"A"B"C"`; did you mean `"ABC"`?). Apply this diff:

```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
+RUN_DUMPER="export TLLM_LOG_LEVEL=error;
 trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
   --model /model \
   --input-data /input/ \
   --output-dir /output \
   $PARALLEL_ARGS \
   "
```

**ChenhanYu:** This seems legit.

**Reply:** @ChenhanYu, thank you for confirming! The fix will prevent the environment variable from receiving a mis-quoted literal string.
**Review comment (lines +34 to +40):** Repair `RUN_DUMPER` quoting so the script actually runs. The current assignment ends the string at the inner quotes around `error`; a quoted heredoc avoids the nesting entirely:

```diff
-RUN_DUMPER="export TLLM_LOG_LEVEL="error";
-trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
-  --model /model \
-  --input-data /input/ \
-  --output-dir /output \
-  $PARALLEL_ARGS \
-  "
+read -r -d '' RUN_DUMPER <<EOF
+export TLLM_LOG_LEVEL="error"
+trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
+  --model /model \
+  --input-data /input/ \
+  --output-dir /output \
+  $PARALLEL_ARGS
+EOF
```
```bash
timeout 235m srun -l \
    --mpi=pmix --overlap \
    --output=%x_%j_$DATETIME.log \
    --container-image ${CONTAINER} \
    --container-mounts ${MOUNTS} \
    bash -c "$RUN_DUMPER"
```
**Review comment (lines +42 to +47):** Fix undefined variable in log filename. Line 44 references `$DATETIME`, which is never defined in this script. Define the variable before use or remove it from the log filename:

```diff
+DATETIME=$(date +%Y%m%d_%H%M%S)
+
 timeout 235m srun -l \
     --mpi=pmix --overlap \
     --output=%x_%j_$DATETIME.log \
     --container-image ${CONTAINER} \
     --container-mounts ${MOUNTS} \
     bash -c "$RUN_DUMPER"
```

Alternatively, if timestamps aren't needed, simplify to:

```diff
 timeout 235m srun -l \
     --mpi=pmix --overlap \
-    --output=%x_%j_$DATETIME.log \
+    --output=%x_%j.log \
     --container-image ${CONTAINER} \
     --container-mounts ${MOUNTS} \
     bash -c "$RUN_DUMPER"
```
**Review comment (on `keep_conversation`):** Replace assertion with proper error handling. The `keep_conversation` filter function uses an assertion to validate that `conversation_id` exists, which will cause the entire process to crash if any conversation lacks this field. In a distributed DP setting, this would fail the entire SLURM job. Replace the assertion:

```diff
 def keep_conversation(entry):
     conversation_id = entry.get("conversation_id", entry.get("uuid", None))
-    assert conversation_id is not None, "conversation_id is required"
+    if conversation_id is None:
+        return False  # Skip conversations without valid ID
     output_file = args.output_dir / f"{conversation_id}.pt"
     return not output_file.exists()
```

Additionally, consider logging a warning when conversations are skipped due to missing IDs, similar to how other invalid conversations are tracked.
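Following up on the suggestion to also log skipped entries, a minimal sketch (assumes `dataset` and `args.output_dir` from the surrounding script; the `skipped_no_id` list is a hypothetical addition):

```python
skipped_no_id = []

def keep_conversation(entry):
    conversation_id = entry.get("conversation_id", entry.get("uuid"))
    if conversation_id is None:
        skipped_no_id.append(entry)  # track instead of crashing the DP job
        return False
    output_file = args.output_dir / f"{conversation_id}.pt"
    return not output_file.exists()

dataset = dataset.filter(keep_conversation)
if skipped_no_id:
    print(f"Warning: skipped {len(skipped_no_id)} conversations without conversation_id/uuid.")
```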