-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path: gen_emb.sh
More file actions
40 lines (37 loc) · 1.76 KB
/
gen_emb.sh
File metadata and controls
40 lines (37 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env bash
#
# Export DETree embeddings for RealBench.
#
# Usage: edit the configuration variables below, then run the script.
# Requires the `detree` package to be importable by `python` and CUDA
# devices available to Fabric (see DEVICE_NUM).

# Strict mode first, so it also covers the configuration section:
# abort on command failure (-e), unset variables (-u), and pipeline
# failures (pipefail).
set -euo pipefail

# --- Configuration ----------------------------------------------------------
DATA_ROOT="/path/to/RealBench" # Root directory that contains the RealBench datasets.
DATABASE_NAME="all" # Dataset alias understood by detree.utils.dataset.load_datapath.
MODEL_NAME="FacebookAI/roberta-large" # Backbone encoder identifier or local Hugging Face-style directory.
SAVE_DIR="/path/to/RealBench/embeddings" # Directory where embedding databases will be stored.
SAVE_NAME="detree_stage1" # Filename (without extension) for the saved embeddings.
DEVICE_NUM=4 # Number of CUDA devices available to Fabric.
BATCH_SIZE=64 # Inference batch size per device.
NUM_WORKERS=8 # DataLoader workers for reading datasets.
MAX_LENGTH=512 # Maximum tokenised sequence length.
POOLING="max" # Embedding pooling strategy.
NEED_LAYER=(16 17 18 19 22 23) # Hidden layers to export from the encoder.
SPLIT="train" # Dataset split to encode (train/test/extra).
# Extra CLI switches can be appended here, for example:
# EXTRA_FLAGS=(--no-adversarial)
EXTRA_FLAGS=(
# --no-adversarial
# --has-mix
)

# --- Run --------------------------------------------------------------------
# NB: EXTRA_FLAGS may be empty; with `set -u`, expanding an empty array via
# "${EXTRA_FLAGS[@]}" is a fatal "unbound variable" error on bash < 4.4
# (e.g. the stock macOS bash 3.2). The ${arr[@]+...} guard expands to
# nothing when the array is empty and to the quoted elements otherwise.
python -m detree.cli.embeddings \
--path "$DATA_ROOT" \
--database-name "$DATABASE_NAME" \
--model-name "$MODEL_NAME" \
--savedir "$SAVE_DIR" \
--name "$SAVE_NAME" \
--device-num "$DEVICE_NUM" \
--batch-size "$BATCH_SIZE" \
--num-workers "$NUM_WORKERS" \
--max-length "$MAX_LENGTH" \
--pooling "$POOLING" \
--need-layer "${NEED_LAYER[@]}" \
--split "$SPLIT" \
${EXTRA_FLAGS[@]+"${EXTRA_FLAGS[@]}"}