Commit ac0071f

Dev (#154)

Authored by dest1n1s, SmallMelon-L, Frankstein73, crabshellman, and Hzfinfdu
* Fix(analysis): fix analysis on single card and multi card
* Revert TransformerLens submodule to previous commit
* feat(generate): add override_dtype setting to control activation dtype in GenerateActivationsSettings
* feat(distributed): add get_process_group utility, trying to fix checkpoint saving during sweeps
* Feat(examples): add training examples
* refactor: use TensorSpecs rather than logging method dispatch for logging with different SAE variants (#130)
* misc: clean up circuit tracing mode and some distributed utils
* refactor: use TensorSpecs rather than logging method dispatch for logging with different SAE variants
* misc: clean up logging method dispatch
* misc: rename tensor_specs to specs
* style: simplify conditions on optional field
* misc(generate): override_dtype for GenerateActivationsSettings should have a default value of None, or pydantic would complain about it
* fix: replace `.item()` with `item(...)` to ensure distributed consistency (a hedged sketch follows this message). See: pytorch/pytorch#152406
* refactor: use Metric classes for disentangled metric computation
* refactor: use Metric classes to run evaluation (#133)
* feat: support resuming wandb run from training checkpoint
  - Add wandb_run_id and wandb_resume config options
  - Save wandb run id when saving checkpoint
  - Load trainer from checkpoint when from_pretrained_path is set
* fix(activation): put mask/attention_mask on the correct device
* fix(activation): use local_map for mask computation on DTensor to ensure correct device placement
* fix(trainer): use ctx.get() for optional coefficients to prevent KeyError
* feat(train): add checkpoint resume support for crosscoder, clt, lorsa and molt runners
* fix(trainer): correct token count calculation for 2D activation in LORSA training
* chore: remove default extra
* misc(evaluator): add type annotation
* docs: workthrough (WIP)
* ci: install with all extras
* fix: type errors due to torch updates on local_map
* fix(TL): add support for the whole Qwen3 family & fix inconsistency in tie-word-embed
* feat: conversion methods between lm-saes and saelens (Co-authored-by: Guancheng Zhou <[email protected]>)
* Fix(examples): fix the activation_factory settings of lorsa examples
* feat(autointerp): refactor to async & support lorsa
* feat(autointerp): better parallelization with async
* feat(database): show progress for database operations (add analysis & update feature)
* feat(autointerp): better parallelization with async
* feat(database): show progress for database operations (add analysis & update feature)
* misc(ruff): fix ruff & typecheck errors
* feat(autointerp): update ui to support autointerp without verification
* fix(format): fix pyright issues
* fix(format): fix pyright issues
* fix(misc): remove try-except logic for progress measure in autointerp
* feat(autointerp): support max suppressing logits in autointerp
* feature(autointerp): improved autointerp prompts and support lorsa autointerp with z pattern
* fix(misc): ruff for autointerp
* fix(misc): ruff for autointerp
* refactor: use tanstack start for frontend; make a more neuronpedia-like ui (#146)
* fix(database): deal with none value
* feat(ui): support paged queries of samples
* feat(ui): set scrollbar-gutter to ensure space is reserved for the scrollbar to prevent layout shifting
* style(ui): fix eslint & prettier
* feat(ui): interpretation with real data
* ci(ui): add eslint & prettier check
* chore: fix pre-commit for both python and typescript
* format(ui): fix eslint & prettier
* format(ui): adjust eslint rules
* fix: pin torch==2.8.0 for dtensor compatibility
  - Pin torch version to 2.8.0 to avoid dtensor-related errors in 2.9.0
  - Remove unused d_model field from LanguageModelConfig
  - Add GPU memory usage display in training progress bar
  - Move batch refresh to end of training loop iteration
* feat(ui): dictionary page (WIP)
* feat(ui): dictionary page
* feat(ui): feature list in feature page (WIP)
* fix(ui): feature list loading previous page causes wrong scroll position
* fix(ui): reinitialize useFeatures hook when the concerned feature index is out of range
* fix(ui): fix feature list height
* fix(metric): support inconsistent batch size
* perf(ui): fetch sample range on demand
* feat(server): support preloading models/saes
* feat(ui): remove in-card progress bar
* fix(ui): fix visible range comparison
* feat(ui): adjust accent color
* fix(optim): add DTensor support for SparseAdam, redistribute grad to match parameter's placements when grad is DTensor
* feat(circuit): major revision:
  - Support circuit tracing with plt+lorsa and plt only; wrap a list of plts into a Transcoder Set, following circuit-tracer.
  - Update QK tracing; we can now see feature-feature pairwise attribution. Efficiency might require revisiting.
  - Refactor attribution structure, breaking down several heavy files. Ready to be further improved, mainly in reducing the numerous `if use_lorsa` branches.
* feat(backend): add DTensor support to TransformerLensLanguageModel
* feat(backend): add DTensor support to TransformerLensLanguageModel
  - Add device_mesh parameter to support distributed inference
  - Implement forward() with local_map for DTensor inputs
  - Add run_with_hooks() that wraps hooks to convert between DTensor and local tensor
  - Update to_activations() to return DTensor when device_mesh is set
* fix(backend): convert placements list to tuple for DTensor comparison
* feat(backend): add run_with_cache and hooks context manager with DTensor support
* fix(backend): skip n_context sync when already provided in to_activations
* fix(attribution): fix missing gradient flow configuration for lorsa QK norm (#149)
* fix(server): expose lru_cache ability from synchronized decorator
* fix(lorsa): fix lorsa init
* fix(lorsa): fix set decoder norm for lorsa
* feat(ui): simply move the original circuit page to ui-ssr
* misc(ui): remove comments
* refactor(ui): split data and visual states; move up feature fetching logic
* refactor(ui): remove standalone CircuitVisualization component
* chore(dependencies): update torch and torch-npu versions to 2.9.0
* fix(lorsa): avoid triggering DTensor bug in torch==2.9.0
* feat(lorsa): init lorsa with the active subspace of V
* feat(metrics): add GradientNormMetric and extend Record with reduction modes
* feat(training): support training lorsa with varying lengths of training sequences. This will make the total number of training tokens inaccurate (#150)
* fix(attribution): fix missing gradient flow configuration for lorsa QK norm
* fix(attribution): fix missing gradient flow configuration for lorsa QK norm
* feat(config): remove all instances of use_batch_norm_mse. We do not want this from now on
* fix(activation): load saved mask and attention_mask
* feat(training): support training lorsa with varying lengths of training sequences. This will make the total number of training tokens inaccurate.
* feat(training): support training lorsa with varying lengths of training sequences. This will make the total number of training tokens inaccurate.
* misc(lorsa): use abstract_sae compute_loss; put l_rec.mean() after loss dict
* fix(training): also transform batch['mask'] to Tensor from DTensor in… (#152)
* fix(training): also transform batch['mask'] to Tensor from DTensor in distributed scenarios
* fix(training): also transform batch['mask'] to Tensor from DTensor in distributed scenarios
* feat(optim): add custom gradient norm computation and clipping for distributed training
* fix: compute_loss DTensor loss shape
* fix(misc): we do not want to filter out eos. It might mark the end of chats and include useful information
* feat: circuit tracing with backend interaction
* fix(server): synchronized decorator type issue
* fix(backend): use TokenizerFast for trace token origins
* fix(ui): better display of dead features
* fix(ui): correctly display truncated z pattern
* fix(ui): minor layout issues
* fix(attribution): add details to some comments
* perf(ui): better visual display for circuit (WIP)
* fix(trainer): remove assertion for clip_grad_norm in distributed training
* feat(transcoder): init transcoder with MLP
* fix(tc): fix type problem
* fix(tc): fix type problem
* feat(ui): hover & click nearest node
* feat(analyze): make FeatureAnalyzer aware of mask
* docs: update installation instructions and example README
* fix(runner): type mismatch
* chore: bump version to 2.0.0b4

---------

Co-authored-by: Junxuan Wang <[email protected]>
Co-authored-by: frankstein <[email protected]>
Co-authored-by: Jiaxing Wu <[email protected]>
Co-authored-by: Zhengfu He@SII <[email protected]>
Co-authored-by: Guancheng Zhou <[email protected]>
Co-authored-by: Guancheng Zhou <[email protected]>
Co-authored-by: StarDust73 <[email protected]>
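The `.item()` item above replaces bare scalar reads with an `item(...)` helper for distributed consistency. As a hedged illustration only — the helper below is hypothetical, not the repository's actual implementation — the core idea is to materialize a `DTensor` on every rank before reading the scalar, so all ranks agree on the value (see pytorch/pytorch#152406):

```python
# Hypothetical sketch of an `item(...)` helper; assumes torch >= 2.4 with
# `torch.distributed.tensor.DTensor`. Not the repository's actual code.
import torch
from torch.distributed.tensor import DTensor


def item(tensor: torch.Tensor) -> float:
    """Read a scalar consistently across ranks.

    `.item()` on a DTensor may observe only the local shard; gathering the
    full tensor first keeps every rank in agreement.
    """
    if isinstance(tensor, DTensor):
        tensor = tensor.full_tensor()  # gather shards into a regular Tensor
    return tensor.item()
```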
2 parents 43f1a6d + 3f35e3d commit ac0071f

159 files changed: +13359 −3449 lines

.github/workflows/checks.yml

Lines changed: 2 additions & 1 deletion

```diff
@@ -46,14 +46,15 @@ jobs:
       uses: actions/checkout@v4
       with:
         submodules: "true"
+
     - name: Install uv
       uses: astral-sh/setup-uv@v5
       with:
         enable-cache: true
         cache-dependency-glob: "uv.lock"

     - name: Install the project
-      run: uv sync --extra default --dev
+      run: uv sync --all-extras --dev

     - name: Type check
       run: uv run basedpyright .
```
.github/workflows/ui-ssr-checks.yml

Lines changed: 42 additions & 0 deletions

```yaml
name: UI SSR Checks

on:
  push:
    branches:
      - main
      - dev
    paths:
      - "ui-ssr/**"
      - ".github/workflows/ui-ssr-checks.yml"
  pull_request:
    branches:
      - main
      - dev
    paths:
      - "ui-ssr/**"
      - ".github/workflows/ui-ssr-checks.yml"

permissions:
  contents: read

jobs:
  check:
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: ui-ssr

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Setup Bun
        uses: oven-sh/setup-bun@v1
        with:
          bun-version: latest

      - name: Install dependencies
        run: bun install

      - name: Run CI checks
        run: bun run ci:check
```
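Per the commit message above ("ci(ui): add eslint & prettier check"), `bun run ci:check` presumably runs the frontend's eslint and prettier checks; the exact script definition would live in `ui-ssr/package.json` and is not shown in this diff.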

.pre-commit-config.yaml

Lines changed: 11 additions & 0 deletions

```diff
@@ -8,3 +8,14 @@ repos:
         args: [--fix]
       # Run the formatter.
       - id: ruff-format
+
+  - repo: local
+    hooks:
+      - id: ui-ssr-lint-staged
+        name: ui-ssr lint-staged
+        # Enter the directory and run lint-staged.
+        # lint-staged will handle the staged files selection and running prettier/eslint.
+        entry: bash -c 'cd ui-ssr && bun run check'
+        language: system
+        pass_filenames: false
+        files: ^ui-ssr/
```
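Once the hook is registered with `pre-commit install`, any commit touching files under `ui-ssr/` runs `bun run check` in that directory; `pass_filenames: false` means lint-staged, rather than pre-commit, decides which staged files to process.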

README.md

Lines changed: 20 additions & 6 deletions

````diff
@@ -30,10 +30,26 @@
 ## Installation

+Use [pip](https://pypi.org/project/pip/) to install Language-Model-SAEs:
+
+```bash
+pip install lm-saes==2.0.0b4
+```
+
+We also highly recommend using [uv](https://docs.astral.sh/uv/) to manage your own project dependencies. You can use
+
+```bash
+uv add lm-saes==2.0.0b4
+```
+
+to add Language-Model-SAEs as your project dependency.
+
+## Development
+
 We use [uv](https://docs.astral.sh/uv/) to manage the dependencies, which is an alternative to [poetry](https://python-poetry.org/) or [pdm](https://pdm-project.org/). To install the required packages, just install [uv](https://docs.astral.sh/uv/getting-started/installation/), and run the following command:

 ```bash
-uv sync --extra default
+uv sync
 ```

 This will install all the required packages for the codebase in `.venv` directory. For Ascend NPU support, run
@@ -47,15 +63,13 @@ A forked version of `TransformerLens` is also included in the dependencies to pr
 If you want to use the visualization tools, you also need to install the required packages for the frontend, which uses [bun](https://bun.sh/) for dependency management. Follow the instructions on the website to install it, and then run the following command:

 ```bash
-cd ui
+cd ui-ssr
 bun install
 ```

-`bun` is not well-supported on Windows, so you may need to use WSL or other Linux-based solutions to run the frontend, or consider using a different package manager, such as `pnpm` or `yarn`.
-
 ## Launch an Experiment

-The guidelines and examples for launching experiments are generally outdated. At this moment, you may explore `src/lm_saes/runners` folder for the interface for generating activations and training & analyzing SAE variants. For analyzing SAEs, a MongoDB instance is required. More instructions will be provided in near future.
+Explore the `examples` to check the basic usage of training/analyzing SAEs in different configurations. Note that a MongoDB instance is recommended for recording the model/dataset/SAE configurations and required for storing analyses. For more advanced usage, you may explore the `src/lm_saes/runners` folder for the interface for generating activations and training & analyzing SAE variants, and directly write your own variant of the training/analyzing script at the runner level.

 ## Visualizing the Learned Dictionary

@@ -65,7 +79,7 @@ The analysis results will be saved using MongoDB, and you can use the provided v
 uvicorn server.app:app --port 24577 --env-file server/.env
 ```

-Then, copy the `ui/.env.example` file to `ui/.env` and modify the `VITE_BACKEND_URL` to fit your server settings (by default, it's `http://localhost:24577`), and start the frontend by running the following command:
+Then, copy the `ui/.env.example` file to `ui/.env` and modify the `BACKEND_URL` to fit your server settings (by default, it's `http://localhost:24577`), and start the frontend by running the following command:

 ```bash
 cd ui
````

docs/index.md

Lines changed: 36 additions & 15 deletions

````diff
@@ -60,40 +60,61 @@ To train a simple Sparse Autoencoder on `blocks.5.hook_resid_post` of a Pythia-1
 ```python
 settings = TrainSAESettings(
     sae=SAEConfig(
-        hook_point_in=f"blocks.5.hook_resid_post",
+        hook_point_in="blocks.6.hook_resid_post",
+        hook_point_out="blocks.6.hook_resid_post",
         d_model=768,
         expansion_factor=8,
-        act_fn="jumprelu",
+        act_fn="topk",
+        top_k=50,
+        dtype=torch.float32,
+        device="cuda",
     ),
     initializer=InitializerConfig(
         grid_search_init_norm=True,
     ),
     trainer=TrainerConfig(
-        lr=5e-5,
-        l1_coefficient=0.3,
+        lr=1e-4,
+        initial_k=50,
+        k_warmup_steps=0.1,
+        k_schedule_type="linear",
         total_training_tokens=800_000_000,
-        sparsity_loss_type="tanh-quad",
-        jumprelu_lr_factor=0.1,
+        log_frequency=1000,
+        eval_frequency=1000000,
+        n_checkpoints=0,
+        check_point_save_mode="linear",
+        exp_result_path="results",
     ),
+    model=LanguageModelConfig(
+        model_name="EleutherAI/pythia-160m",
+        device="cuda",
+        dtype="torch.float16",
+    ),
+    model_name="pythia-160m",
+    datasets={
+        "SlimPajama-3B": DatasetConfig(
+            dataset_name_or_path="Hzfinfdu/SlimPajama-3B",
+        )
+    },
     wandb=WandbConfig(
         wandb_project="lm-saes",
-        exp_name=name,
+        exp_name="pythia-160m-sae",
     ),
     activation_factory=ActivationFactoryConfig(
         sources=[
-            ActivationFactoryActivationsSource(
-                path=Path(args.activation_path).expanduser(),
-                name=f"pythia-160m-1d",
-                device="cuda",
-                dtype=torch.float32,
+            ActivationFactoryDatasetSource(
+                name="SlimPajama-3B",
             )
         ],
         target=ActivationFactoryTarget.ACTIVATIONS_1D,
-        hook_points=["blocks.5.hook_resid_post"],
+        hook_points=["blocks.6.hook_resid_post"],
         batch_size=4096,
-        buffer_size=None,
+        buffer_size=4096 * 4,
+        buffer_shuffle=BufferShuffleConfig(
+            perm_seed=42,
+            generator_device="cuda",
+        ),
     ),
-    sae_name="L5R",
+    sae_name="pythia-160m-sae",
     sae_series="pythia-sae",
 )
 train_sae(settings)
````
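The updated example switches from a JumpReLU SAE to a TopK one, with `initial_k`, `k_warmup_steps=0.1` and `k_schedule_type="linear"` controlling a warmup of the sparsity level k. As a rough sketch of what a linear k-schedule of this shape could compute — a hypothetical helper, assuming `k_warmup_steps` is a fraction of total steps and k is interpolated from `initial_k` to the final `top_k`; the library's actual schedule may differ:

```python
# Hypothetical sketch of a linear TopK warmup schedule; not the library's code.
def k_at_step(step: int, total_steps: int, initial_k: int, top_k: int, k_warmup_steps: float = 0.1) -> int:
    """Interpolate k linearly from `initial_k` to `top_k` over the warmup span."""
    warmup = max(int(total_steps * k_warmup_steps), 1)
    if step >= warmup:
        return top_k
    return round(initial_k + (top_k - initial_k) * step / warmup)
```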

docs/workthrough.md

Lines changed: 11 additions & 0 deletions

````markdown
# Workthrough

`Language-Model-SAEs` provides a general way to train, analyze and visualize Sparse Autoencoders and their variants. To help you get started quickly, we've included [example scripts]() that guide you through each stage of working with SAEs. This guide begins with a foundational example and progressively introduces the core features and capabilities of the library.

## Training Basic Sparse Autoencoders

A [Sparse Autoencoder]() is trained to reconstruct model activations at a specific position. We depend on [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) to take activations out of the model forward pass, specified by hook points. To train a vanilla SAE on the Pythia 160M Layer 6 output, you can create the following `TrainSAESettings`:

```python

```
````
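The empty code block in this new walkthrough page is a WIP stub (the commit message lists "docs: workthrough (WIP)"). For the hook-point mechanism the paragraph describes, here is a minimal standalone TransformerLens example — `run_with_cache` and `names_filter` are the library's real API, though the prompt and model alias here are illustrative:

```python
# Minimal TransformerLens example: cache activations at a single hook point.
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("pythia-160m")
tokens = model.to_tokens("Sparse autoencoders reconstruct activations.")
# names_filter restricts the cache to the hook point we care about.
_, cache = model.run_with_cache(tokens, names_filter="blocks.6.hook_resid_post")
acts = cache["blocks.6.hook_resid_post"]  # shape: (batch, n_context, d_model=768)
```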

examples/README.md

Lines changed: 17 additions & 0 deletions

```markdown
# Example setups of Language-Model-SAEs

The standard SAE-based pipeline for mechanistically interpreting internal representations of language models contains the following steps: Generating activations (optional) -> Training SAEs -> Analyzing SAEs -> Visualizing analyses.

Here we present example setups for generating, training and analyzing, covering variants of SAE architectures, activation functions, and the choice of whether to use pre-generated activations.

## Use on-the-fly model activations

SAE training requires a stream of model activations at certain hook points (i.e., specific locations in the model's internal representations). Model activations can either be cached ahead of time on disk, or produced on the fly.

For on-the-fly model activation usage, the _Generating activations_ step can be skipped, and thus the overall pipeline is simplified. You can refer to [train_pythia_sae_topk](https://github.com/OpenMOSS/Language-Model-SAEs/blob/main/examples/train_pythia_sae_topk.py), [analyze_pythia_sae](https://github.com/OpenMOSS/Language-Model-SAEs/blob/main/examples/analyze_pythia_sae.py), and other scripts without a `with_pre_generated_activations` suffix to launch the experiments on Pythia. Note that analyzing requires a running MongoDB instance (defaulting to `mongodb://localhost:27017`) to save the analysis results.

## Use cached activations

Cached activations are the more common usage in practical SAE training and analyzing. They enable effective hyperparameter sweeping by reusing generated activations, and also enable parallelized training and analyzing (DP/TP). However, they require a non-trivial amount of disk space, e.g., caching 800M tokens of one layer's activations of Pythia 160M requires about 6TB of space.

To launch experiments with cached activations, you should first generate activations in 1d shape (`(batch, d_model)`, for training use) and 2d shape (`(batch, n_context, d_model)`, for analyzing use), by running [generate_pythia_activation_1d](https://github.com/OpenMOSS/Language-Model-SAEs/blob/main/examples/generate_pythia_activation_1d.py) and [generate_pythia_activation_2d](https://github.com/OpenMOSS/Language-Model-SAEs/blob/main/examples/generate_pythia_activation_2d.py). Then, you can use [train_pythia_sae_with_pre_generated_activations](https://github.com/OpenMOSS/Language-Model-SAEs/blob/main/examples/train_pythia_sae_with_pre_generated_activations.py) and [analyze_pythia_sae_with_pre_generated_activations](https://github.com/OpenMOSS/Language-Model-SAEs/blob/main/examples/analyze_pythia_sae_with_pre_generated_activations.py) to run training and analyzing respectively, with a pre-generated activation path specified. Note that analyzing still requires a MongoDB instance running.
```
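The 1d/2d distinction above is a matter of whether the context dimension is kept: analysis needs per-context activations, while training consumes a flat stream of token activations. A minimal sketch of the relationship between the two layouts (shapes taken from the examples; the tensors here are random placeholders):

```python
import torch

# 2d activations keep the context dimension, as used for analysis.
acts_2d = torch.randn(16, 2048, 768)   # (batch, n_context, d_model)

# 1d activations flatten contexts into a stream of tokens, as used for training.
acts_1d = acts_2d.reshape(-1, 768)     # (batch * n_context, d_model)
```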

examples/analyze_pythia_sae.py

Lines changed: 71 additions & 0 deletions

```python
import argparse
import os

import torch

from lm_saes import (
    ActivationFactoryConfig,
    ActivationFactoryDatasetSource,
    ActivationFactoryTarget,
    AnalyzeSAESettings,
    DatasetConfig,
    FeatureAnalyzerConfig,
    LanguageModelConfig,
    MongoDBConfig,
    SAEConfig,
    analyze_sae,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sae_path", type=str, required=True)
    return parser.parse_args()


if __name__ == "__main__":
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    args = parse_args()

    sae_cfg = SAEConfig.from_pretrained(os.path.expanduser(args.sae_path), device="cuda", dtype=torch.float16)
    analyze_sae(
        AnalyzeSAESettings(
            sae=sae_cfg,
            sae_name="pythia-160m-sae",
            sae_series="pythia-sae",
            activation_factory=ActivationFactoryConfig(
                sources=[
                    ActivationFactoryDatasetSource(
                        name="SlimPajama-3B",
                    )
                ],
                target=ActivationFactoryTarget.ACTIVATIONS_2D,
                hook_points=["blocks.6.hook_resid_post"],
                batch_size=16,
                context_size=2048,
            ),
            model=LanguageModelConfig(
                model_name="EleutherAI/pythia-160m",
                device="cuda",
                dtype="torch.float16",
            ),
            model_name="pythia-160m",
            datasets={
                "SlimPajama-3B": DatasetConfig(
                    dataset_name_or_path="Hzfinfdu/SlimPajama-3B",
                )
            },
            analyzer=FeatureAnalyzerConfig(
                total_analyzing_tokens=100_000_000,
                subsamples={
                    "top_activations": {"proportion": 1.0, "n_samples": 20},
                    "subsampling_80%": {"proportion": 0.8, "n_samples": 10},
                    "subsampling_60%": {"proportion": 0.6, "n_samples": 10},
                    "subsampling_40%": {"proportion": 0.4, "n_samples": 10},
                    "non_activating": {"proportion": 0.3, "n_samples": 20, "max_length": 50},
                },
            ),
            mongo=MongoDBConfig(),
            device_type="cuda",
        )
    )
```
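Since the script reads `LOCAL_RANK` and constructs `MongoDBConfig()` with defaults, it is presumably meant to be launched under a distributed launcher such as `torchrun`, with a MongoDB instance reachable at the default `mongodb://localhost:27017` (per `examples/README.md` above); only `--sae_path` is required on the command line.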
examples/analyze_pythia_sae_with_pre_generated_activations.py

Lines changed: 62 additions & 0 deletions

```python
import argparse
import os

import torch

from lm_saes import (
    ActivationFactoryActivationsSource,
    ActivationFactoryConfig,
    ActivationFactoryTarget,
    AnalyzeSAESettings,
    FeatureAnalyzerConfig,
    MongoDBConfig,
    SAEConfig,
    analyze_sae,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sae_path", type=str, required=True)
    parser.add_argument("--activation_path", type=str, required=True)
    return parser.parse_args()


if __name__ == "__main__":
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    args = parse_args()

    sae_cfg = SAEConfig.from_pretrained(os.path.expanduser(args.sae_path), device="cuda", dtype=torch.float16)
    analyze_sae(
        AnalyzeSAESettings(
            sae=sae_cfg,
            sae_name="pythia-160m-sae",
            sae_series="pythia-sae",
            activation_factory=ActivationFactoryConfig(
                sources=[
                    ActivationFactoryActivationsSource(
                        path=str(args.activation_path),
                        name="pythia-160m-2d",
                        device="cuda",
                        dtype=torch.float16,
                    )
                ],
                target=ActivationFactoryTarget.ACTIVATIONS_2D,
                hook_points=["blocks.6.hook_resid_post"],
                batch_size=16,
                context_size=2048,
            ),
            analyzer=FeatureAnalyzerConfig(
                total_analyzing_tokens=100_000_000,
                subsamples={
                    "top_activations": {"proportion": 1.0, "n_samples": 20},
                    "subsampling_80%": {"proportion": 0.8, "n_samples": 10},
                    "subsampling_60%": {"proportion": 0.6, "n_samples": 10},
                    "subsampling_40%": {"proportion": 0.4, "n_samples": 10},
                    "non_activating": {"proportion": 0.3, "n_samples": 20, "max_length": 50},
                },
            ),
            mongo=MongoDBConfig(),
            device_type="cuda",
        )
    )
```
