Skip to content

Commit b9446c4

Browse files
committed
fix(util): move is_primary_rank to distributed utils to avoid circular import
1 parent e8890a4 commit b9446c4

File tree

10 files changed: 20 additions and 27 deletions.

10 files changed

+20
-27
lines changed

src/lm_saes/abstract_sae.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,10 @@
3535
from lm_saes.backend.language_model import LanguageModelConfig
3636
from lm_saes.config import BaseModelConfig
3737
from lm_saes.utils.auto import PretrainedSAEType, auto_infer_pretrained_sae_type
38-
from lm_saes.utils.distributed import DimMap, distributed_topk, item, mesh_dim_size
38+
from lm_saes.utils.distributed import DimMap, distributed_topk, is_primary_rank, item, mesh_dim_size
3939
from lm_saes.utils.distributed.utils import execute_and_broadcast
4040
from lm_saes.utils.logging import get_distributed_logger
4141
from lm_saes.utils.math import topk
42-
from lm_saes.utils.misc import (
43-
is_primary_rank,
44-
)
4542
from lm_saes.utils.tensor_specs import TensorSpecs, apply_token_mask
4643
from lm_saes.utils.timer import timer
4744

src/lm_saes/analysis/feature_analyzer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
from lm_saes.crosscoder import CrossCoder
1717
from lm_saes.sparse_dictionary import SparseDictionary
1818
from lm_saes.utils.discrete import KeyedDiscreteMapper
19-
from lm_saes.utils.distributed import DimMap, masked_fill, to_local
19+
from lm_saes.utils.distributed import DimMap, is_primary_rank, masked_fill, to_local
2020
from lm_saes.utils.distributed.ops import item
21-
from lm_saes.utils.misc import is_primary_rank
2221
from lm_saes.utils.tensor_dict import concat_dict_of_tensor, sort_dict_of_tensor
2322

2423
from .post_analysis import PostAnalysisProcessor, get_post_analysis_processor

src/lm_saes/analysis/post_analysis/lorsa.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
from lm_saes.lorsa import LowRankSparseAttention
3535
from lm_saes.sparse_dictionary import SparseDictionary
3636
from lm_saes.utils.discrete import KeyedDiscreteMapper
37+
from lm_saes.utils.distributed import is_primary_rank
3738
from lm_saes.utils.distributed.ops import item
3839
from lm_saes.utils.logging import get_distributed_logger
39-
from lm_saes.utils.misc import is_primary_rank
4040

4141
from .base import PostAnalysisProcessor, register_post_analysis_processor
4242

src/lm_saes/runners/topk_to_jumprelu_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from lm_saes.database import MongoClient, MongoDBConfig
1515
from lm_saes.resource_loaders import load_dataset, load_model
1616
from lm_saes.sparse_dictionary import SparseDictionary
17+
from lm_saes.utils.distributed import is_primary_rank
1718
from lm_saes.utils.logging import get_distributed_logger, setup_logging
18-
from lm_saes.utils.misc import is_primary_rank
1919
from lm_saes.utils.topk_to_jumprelu_conversion import topk_to_jumprelu_conversion
2020

2121
from .utils import PretrainedSAE, load_config

src/lm_saes/runners/train.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@
2222
from lm_saes.resource_loaders import load_dataset, load_model
2323
from lm_saes.sparse_dictionary import SparseDictionary, SparseDictionaryConfig
2424
from lm_saes.trainer import Trainer, TrainerConfig, WandbConfig
25-
from lm_saes.utils.distributed import mesh_rank
25+
from lm_saes.utils.distributed import is_primary_rank, mesh_rank
2626
from lm_saes.utils.logging import get_distributed_logger, setup_logging
27-
from lm_saes.utils.misc import is_primary_rank
2827

2928
from .utils import PretrainedSAE, load_config
3029

src/lm_saes/sparse_dictionary.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,10 @@
3535
from lm_saes.backend.language_model import LanguageModelConfig
3636
from lm_saes.config import BaseModelConfig
3737
from lm_saes.utils.auto import PretrainedSAEType, auto_infer_pretrained_sae_type
38-
from lm_saes.utils.distributed import DimMap, distributed_topk, item, mesh_dim_size
38+
from lm_saes.utils.distributed import DimMap, distributed_topk, is_primary_rank, item, mesh_dim_size
3939
from lm_saes.utils.distributed.utils import execute_and_broadcast
4040
from lm_saes.utils.logging import get_distributed_logger
4141
from lm_saes.utils.math import topk
42-
from lm_saes.utils.misc import (
43-
is_primary_rank,
44-
)
4542
from lm_saes.utils.tensor_specs import TensorSpecs, apply_token_mask
4643
from lm_saes.utils.timer import timer
4744

src/lm_saes/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@
3535
)
3636
from lm_saes.optim import SparseAdam, clip_grad_norm, get_scheduler
3737
from lm_saes.sparse_dictionary import SparseDictionary
38+
from lm_saes.utils.distributed import is_primary_rank
3839
from lm_saes.utils.distributed.ops import item
3940
from lm_saes.utils.logging import get_distributed_logger, log_metrics
4041
from lm_saes.utils.misc import (
4142
convert_str_to_torch_dtype,
4243
convert_torch_dtype_to_str,
43-
is_primary_rank,
4444
)
4545
from lm_saes.utils.tensor_specs import apply_token_mask
4646
from lm_saes.utils.timer import timer

src/lm_saes/utils/distributed/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
all_gather_dict,
55
all_gather_list,
66
get_process_group,
7+
is_primary_rank,
78
mesh_dim_rank,
89
mesh_dim_size,
910
mesh_rank,
@@ -12,6 +13,7 @@
1213

1314
__all__ = [
1415
"DimMap",
16+
"is_primary_rank",
1517
"distributed_topk",
1618
"item",
1719
"masked_fill",

src/lm_saes/utils/distributed/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,17 @@
77
from torch.distributed.device_mesh import DeviceMesh
88
from torch.distributed.tensor import Placement
99

10-
from lm_saes.utils.misc import is_primary_rank
10+
11+
def is_primary_rank(device_mesh: DeviceMesh | None, dim_name: str = "sweep") -> bool:
12+
"""Return True if this rank is at coordinate 0 on every mesh dimension other than ``dim_name`` (True if there is no mesh; False if coordinates or dim names are unavailable)."""
13+
if device_mesh is None:
14+
return True
15+
coord = device_mesh.get_coordinate()
16+
mesh_dim_names = device_mesh.mesh_dim_names
17+
if coord is None or mesh_dim_names is None:
18+
return False
19+
coord = [c for i, c in enumerate(coord) if dim_name not in mesh_dim_names or i != mesh_dim_names.index(dim_name)]
20+
return all(c == 0 for c in coord)
1121

1222

1323
def all_gather_dict(

src/lm_saes/utils/misc.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,6 @@ def is_master() -> bool:
1818
return not dist.is_initialized() or dist.get_rank() == 0
1919

2020

21-
def is_primary_rank(device_mesh: DeviceMesh | None, dim_name: str = "sweep") -> bool:
22-
if device_mesh is None:
23-
return True
24-
coord = device_mesh.get_coordinate()
25-
mesh_dim_names = device_mesh.mesh_dim_names
26-
if coord is None or mesh_dim_names is None:
27-
return False
28-
coord = [c for i, c in enumerate(coord) if dim_name not in mesh_dim_names or i != mesh_dim_names.index(dim_name)]
29-
return all(c == 0 for c in coord)
30-
31-
3221
def print_once(
3322
*values: object,
3423
sep: str | None = " ",

0 commit comments

Comments
 (0)