Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
178 commits
Select commit Hold shift + click to select a range
b2e6ae0
Add muP fields, auto-update model cfg
daviswer Jul 17, 2024
4a02c82
Add mup scaling to fsdp init params
daviswer Jul 18, 2024
22c54a6
Only set mup cfg if >0
daviswer Jul 18, 2024
af52614
1d init mup
daviswer Jul 19, 2024
57ed6f9
Attempt mup lrs
daviswer Jul 19, 2024
372e1d2
cleanup, typofix
daviswer Jul 19, 2024
c0d1d1f
diag print
daviswer Jul 19, 2024
2017a98
Non double list comp
daviswer Jul 19, 2024
9a77a2b
diag print
daviswer Jul 19, 2024
6c01a0b
Stop named params
daviswer Jul 19, 2024
101652b
List sum
daviswer Jul 19, 2024
49341e1
diag print
daviswer Jul 19, 2024
58c1662
diag print
daviswer Jul 19, 2024
a14f57e
diag print
daviswer Jul 19, 2024
5c8d8c4
diag print
daviswer Jul 19, 2024
d0e4888
diag print
daviswer Jul 19, 2024
e9701a1
Iterate over submodules explicitly
daviswer Jul 19, 2024
0c46c3a
linear submods only
daviswer Jul 19, 2024
58ce680
diag print
daviswer Jul 19, 2024
39c5832
diag print
daviswer Jul 19, 2024
476dca5
Use orig params
daviswer Jul 19, 2024
a11abf7
Remove default lr arg
daviswer Jul 19, 2024
f2c5590
Enlist param groups
daviswer Jul 19, 2024
63a834a
divide by mup scales
daviswer Jul 19, 2024
5887896
Remove tele configs
daviswer Jul 22, 2024
4dd3998
Don't change Llama2 small configs
daviswer Jul 22, 2024
1491706
linting
daviswer Jul 22, 2024
34424a5
Merge branch 'main' into mup-beta
daviswer Oct 9, 2024
37f0fa3
mup param consolidation
daviswer Oct 9, 2024
2b2927a
Policy arg rectify
daviswer Oct 9, 2024
cc146f8
LR reporting correction
daviswer Oct 9, 2024
1eb7e47
LR checks
daviswer Oct 9, 2024
df14343
Checks passed
daviswer Oct 9, 2024
135e911
Enable unfused glu
daviswer Oct 9, 2024
db1d4f2
Fix init fn
daviswer Oct 9, 2024
1226260
orig params off
daviswer Oct 9, 2024
e538862
Set lr div in both places
daviswer Oct 9, 2024
c8fe19d
Re on mup lr
daviswer Oct 9, 2024
c44db52
Orig param back on
daviswer Oct 9, 2024
42827a2
Remove prints, implement search
daviswer Oct 9, 2024
59f5e0b
Add mup model cfgs
daviswer Oct 9, 2024
6e803b8
Fix syntax and report spacing
daviswer Oct 9, 2024
733fb91
Fix val reporting
daviswer Oct 9, 2024
f184b3f
Fix reporting again
daviswer Oct 9, 2024
a34204c
Fix earlystop check
daviswer Oct 9, 2024
d49c0c6
Fix early stop check
daviswer Oct 10, 2024
915a6ea
Fix early stop check
daviswer Oct 10, 2024
e7d392b
Flip up/down order
daviswer Oct 10, 2024
cf95c16
Remove various warnings and prints
daviswer Oct 10, 2024
7ea2ade
More suppression
daviswer Oct 10, 2024
cacde5f
Remove ckpdataset since we're not ckping anyways
daviswer Oct 10, 2024
76fc955
Reset dynamo cache between runs
daviswer Oct 10, 2024
c18c878
Clear cache between runs, reorder params
daviswer Oct 10, 2024
9473afd
skew based attn/ffn init
daviswer Oct 11, 2024
4642aac
shorten skew name
daviswer Oct 11, 2024
1a00cd5
Nelder Mead
daviswer Oct 14, 2024
ee9ca22
Nelder Mead pt2
daviswer Oct 14, 2024
6323df1
Nelder Mead pt3
daviswer Oct 14, 2024
896c37b
Nelder Mead pt4
daviswer Oct 14, 2024
471142b
Nelder Mead pt5
daviswer Oct 14, 2024
eba5453
Nelder Mead pt6
daviswer Oct 14, 2024
d2498dd
Nelder Mead pt7
daviswer Oct 14, 2024
e376f88
Nelder Mead pt8
daviswer Oct 14, 2024
488d956
Nelder Mead pt9
daviswer Oct 14, 2024
ca8350d
Final reporting fix, memory clear
daviswer Oct 14, 2024
b6fe56a
Reported sorted simplex
daviswer Oct 14, 2024
c081f99
Delta reporting
daviswer Oct 14, 2024
c09a163
Report correct prior for inside contraction
daviswer Oct 14, 2024
d6290cd
Cleanup reporting, memory
daviswer Oct 14, 2024
2de6aff
Mem cleanup attempts
daviswer Oct 14, 2024
64a599b
Snapshotting after 1
daviswer Oct 14, 2024
62ca17d
Snapshotting after 1 pt2
daviswer Oct 14, 2024
b02e9cb
Create/destroy process group between runs
daviswer Oct 14, 2024
286bc93
diag print
daviswer Oct 14, 2024
0b1fd52
Slow it down, give time for construction
daviswer Oct 14, 2024
41645ba
Forget process grouping, just explicit gc
daviswer Oct 14, 2024
e25c5bf
Forget process grouping, just explicit gc pt2
daviswer Oct 14, 2024
d4ba5f2
Forget deletion, try skipping stuff
daviswer Oct 14, 2024
1ade500
Decrease train steps, vanilla otherwise
daviswer Oct 14, 2024
5cc41b3
Early return single step
daviswer Oct 14, 2024
8f2280d
Add reporting
daviswer Oct 14, 2024
f4961ac
After forward
daviswer Oct 14, 2024
fefe5a6
a whole lotta single steps (make a journey)
daviswer Oct 14, 2024
72345dd
Single stepping, full
daviswer Oct 14, 2024
74142e3
Delete stuff at source
daviswer Oct 14, 2024
f006cfa
Orig params off
daviswer Oct 14, 2024
45483c2
Orig params off pt2
daviswer Oct 14, 2024
49d39d9
Orig params off pt3
daviswer Oct 14, 2024
b11f66a
Restore proper param grouping
daviswer Oct 14, 2024
46b0921
Remove rope init call
daviswer Oct 14, 2024
ab2303c
Revert rope off
daviswer Oct 14, 2024
6012e8b
Move to cpu before del
daviswer Oct 15, 2024
b3141aa
Does destroy pg even help
daviswer Oct 15, 2024
961ac0e
create/destroy each pg, with verbosity
daviswer Oct 15, 2024
5cf312d
Add a delay
daviswer Oct 15, 2024
4e96f93
Barrier
daviswer Oct 15, 2024
8963e4e
Add a delay
daviswer Oct 15, 2024
519990f
Just play with pg stuff for now
daviswer Oct 15, 2024
4f45c71
longer wait, quit flooding
daviswer Oct 15, 2024
2b3e58e
Shorter wait, diag print
daviswer Oct 15, 2024
d329779
Staggered calls, collective tests
daviswer Oct 15, 2024
d67f824
new group, keep global
daviswer Oct 15, 2024
03a4e44
Fix inside contraction test
daviswer Oct 15, 2024
fdd4d44
Reverse init for test
daviswer Oct 15, 2024
a5ab9cc
Reduced report only at end
daviswer Oct 15, 2024
af85d06
Centering offset
daviswer Oct 15, 2024
00e8825
Fix logging/reporting
daviswer Oct 15, 2024
0453796
Fix initial simplexing and reporting
daviswer Oct 15, 2024
0fab0a9
Wider search radius
daviswer Oct 15, 2024
90a3cc2
Diag print
daviswer Oct 15, 2024
fe99e04
Diag print 2
daviswer Oct 15, 2024
36db119
pull initial vals from llama cfg
daviswer Oct 15, 2024
edd5108
pull initial vals from llama cfg
daviswer Oct 15, 2024
d2ad704
pull initial vals from llama cfg
daviswer Oct 15, 2024
6fbcd49
diag print off
daviswer Oct 15, 2024
a8b81ca
Flip simplex test
daviswer Oct 16, 2024
27ce2e5
Adaptive Nelder-Mead
daviswer Oct 16, 2024
c0f3434
More verbosity, sign flip to orig
daviswer Oct 16, 2024
cd3c714
Set up hyperparam sweep for mup impl checking
daviswer Oct 17, 2024
0bfb970
Resweep
daviswer Oct 17, 2024
ad2697e
Regular simplex hedron
daviswer Oct 17, 2024
ec819e0
2d cat (meow)
daviswer Oct 17, 2024
dae9aa8
Fix early reporting
daviswer Oct 17, 2024
1022070
Flip initial simplex
daviswer Oct 17, 2024
9790891
Flip back
daviswer Oct 17, 2024
d6711f1
nan handling
daviswer Oct 17, 2024
5b644fc
Flip back back
daviswer Oct 17, 2024
7618d63
Flip back back back (now you're just lazy)
daviswer Oct 17, 2024
a586790
Early early stop
daviswer Oct 17, 2024
1fb181d
diag print
daviswer Oct 18, 2024
0329ddd
Basic test
daviswer Oct 18, 2024
a4d3e64
Basic test
daviswer Oct 18, 2024
99b1d80
Tweaking set mups
daviswer Oct 18, 2024
27716da
Tweaking set mups
daviswer Oct 18, 2024
f6388c8
Tweaking set mups
daviswer Oct 18, 2024
07d9e8e
Tweaking set mups
daviswer Oct 18, 2024
1b66f56
Tweaking set mups
daviswer Oct 18, 2024
d28163c
Tweaking set mups
daviswer Oct 18, 2024
b0bc6bf
Diag print off, fix sweep also
daviswer Oct 18, 2024
6943b21
Diag print off, fix sweep also
daviswer Oct 18, 2024
d210661
Quit temp test
daviswer Oct 18, 2024
7bb86f9
Simplex flip
daviswer Oct 18, 2024
e70e05c
Read final values from cfg (llama cfg has no lr)
daviswer Oct 21, 2024
4b2c731
SINE FLEP!!!
daviswer Oct 21, 2024
89b7343
Prep for 70b run
daviswer Oct 26, 2024
a6578a9
Merge branch 'mup-search' of github.com:daviswer/fms-fsdp into mup-se…
daviswer Oct 26, 2024
d0c688a
Random sign flip for init simplex
daviswer Oct 27, 2024
9c7209b
Custom init
daviswer Oct 27, 2024
d7181be
Reset compile cache each run
daviswer Oct 28, 2024
5822a33
No flips, just rely on diff seeds
daviswer Oct 28, 2024
8f7dc8a
upper lim 128k on doc len for parquet
daviswer Oct 29, 2024
1387529
Multiple hardcoded col fields
daviswer Oct 29, 2024
671d970
Patch 1 (#8)
lchu6 Oct 30, 2024
b338f98
Blacking and file size splitting (rather than assume equal)
daviswer Oct 30, 2024
927ac1b
Fix parquet edge cases by reading cols/rows from metadata
daviswer Nov 1, 2024
ebbd77e
Verbose detect countfile
daviswer Nov 5, 2024
336ceeb
Verbose no detect countfile
daviswer Nov 5, 2024
08251b5
Fix quotes
daviswer Nov 5, 2024
e93f883
Print all ranks
daviswer Nov 5, 2024
10336dd
Mega verbose
daviswer Nov 5, 2024
117af56
Attempt to read filesizes from metadata
daviswer Nov 5, 2024
740005f
Full shard path metadata lookup
daviswer Nov 5, 2024
b1cc8e2
Partial path lookup
daviswer Nov 5, 2024
21f7ccd
Fix metadata path reading w/o prefix
daviswer Nov 5, 2024
a2a1df7
Dif prints, fix empty prefix error for doc count reader too
daviswer Nov 5, 2024
62b631c
Cumulative stat print, remove diag prints
daviswer Nov 6, 2024
d6c7817
Fix string formatting
daviswer Nov 6, 2024
2d8170f
Further fix formatting, print once
daviswer Nov 6, 2024
64b75ef
Add checkpointing
daviswer Nov 8, 2024
e989197
Temp disable countfilesizes for testing
daviswer Nov 8, 2024
df4715a
Fix off by one reload error, renable countfilesize
daviswer Nov 8, 2024
742e509
Fix off by one reload error, renable countfilesize
daviswer Nov 8, 2024
170555e
Simple train script
daviswer Nov 15, 2024
49b4ff9
Diag print
daviswer Nov 15, 2024
660de0a
Verbose loop
daviswer Nov 15, 2024
73f91fa
Use vals as vals, not multipliers
daviswer Nov 15, 2024
012c0ed
Use vals as vals, not multipliers
daviswer Nov 15, 2024
0e53783
singleton drop token handling
daviswer Nov 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions fms_fsdp/config/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ class train_config:
# compile
use_torch_compile: bool = True

# muP scale params
mup_emb_scale: float = 0
mup_head_scale: float = 0
mup_a_f_skew: float = 0
mup_attn_temp: float = 0
mup_lr_dscale: float = 0
mup_explore_range: float = 5.0
mup_search_steps: int = 10

# speculator training
tp_size: int = 8
model_arch: str = "embedllama"
Expand Down
25 changes: 16 additions & 9 deletions fms_fsdp/policies/param_init.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import torch
from fms.modules.attention import MultiHeadAttention
from fms.modules.attention import QKV, MultiHeadAttention
from fms.modules.embedding import WordEmbedding
from fms.modules.feedforward import GatedLinearUnit
from fms.modules.layernorm import LayerNormParameterized


# for details, read https://github.com/foundation-model-stack/fms-fsdp/issues/64
# for details, read https://github.com/foundation-model-stack/fms-fsdp/issues/64
def param_init_function(module, cfg):
    """Materialize and reset a (possibly meta-device) module with muP scaling.

    Args:
        module: candidate submodule from the FSDP param-init sweep. Only the
            module types listed in ``scales`` below are reset; anything else
            is left untouched.
        cfg: job/training config carrying the muP hyperparameters
            (``mup_a_f_skew``, ``mup_emb_scale``, ``mup_head_scale``).
    """
    # Per-type scale passed to reset_parameters. Order matters: the first
    # matching type wins, mirroring the original list.index(True) dispatch.
    scales = {
        MultiHeadAttention: cfg.mup_a_f_skew**0.5,
        QKV: cfg.mup_a_f_skew**0.5,
        GatedLinearUnit: cfg.mup_a_f_skew**-0.5,
        WordEmbedding: (cfg.mup_emb_scale, cfg.mup_head_scale),
        LayerNormParameterized: 1,
    }
    for module_type, scale in scales.items():
        if isinstance(module, module_type):
            # Allocate real storage on the current GPU before initializing.
            module.to_empty(device=torch.cuda.current_device())
            with torch.no_grad():
                module.reset_parameters(scale=scale)
            break
76 changes: 74 additions & 2 deletions fms_fsdp/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def get_model_config(model_variant):
emb_dim=2048,
nheads=16,
nlayers=24,
hidden_grow_factor=3,
kvheads=4,
)
elif model_variant == "llama3_8b":
llama_config = LLaMAConfig(
Expand Down Expand Up @@ -128,6 +126,39 @@ def get_model_config(model_variant):
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_3.2b_4k_mup_tiny":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=384,
nheads=3,
kvheads=1,
nlayers=24,
hidden_grow_factor=8 / 3,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_3.2b_4k_mup_small":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=768,
nheads=6,
kvheads=2,
nlayers=24,
hidden_grow_factor=8 / 3,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_3.2b_4k_mup_medium":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=1536,
nheads=12,
kvheads=4,
nlayers=24,
hidden_grow_factor=8 / 3,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_70b":
llama_config = LLaMAConfig(
src_vocab_size=128256,
Expand All @@ -150,6 +181,39 @@ def get_model_config(model_variant):
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_70b_4k_medium":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=3172,
nheads=24,
kvheads=3,
nlayers=80,
hidden_grow_factor=3.5,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_70b_4k_small":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=2048,
nheads=16,
kvheads=2,
nlayers=80,
hidden_grow_factor=3.5,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_70b_4k_tiny":
llama_config = LLaMAConfig(
src_vocab_size=128256,
emb_dim=1024,
nheads=8,
kvheads=1,
nlayers=80,
hidden_grow_factor=3.5,
max_expected_seq_len=4096,
rope_theta=500000.0,
)
elif model_variant == "llama3_194m_4k":
llama_config = LLaMAConfig(
src_vocab_size=128256,
Expand All @@ -163,3 +227,11 @@ def get_model_config(model_variant):
raise ValueError(f"model variant {model_variant} not supported.")

return llama_config


def set_mup_from_cfg(job_cfg, model_cfg):
    """Copy positive muP hyperparameters from the job config to the model config.

    Only attributes whose name contains "mup" and whose value is strictly
    positive are considered, and only those the model config already defines
    are overwritten. Returns the (mutated) model config.
    """
    for name, value in vars(job_cfg).items():
        if "mup" not in name or not value > 0:
            continue
        if hasattr(model_cfg, name):
            setattr(model_cfg, name, value)
    return model_cfg
16 changes: 8 additions & 8 deletions fms_fsdp/utils/dataloader_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def causal_lm(data_seq, prompt_len=1):
Perform causal language modeling by right-shifting the input sequence.
Sets first prompt_len tokens to be ignored by the loss.
"""
data_seq = torch.tensor(data_seq, dtype=torch.int)
data_seq = data_seq.int()
t = data_seq.clone()[1:]
data_seq = data_seq[:-1]
t[:prompt_len] = -100
Expand Down Expand Up @@ -132,13 +132,13 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
data = PreprocessDataset(data, p)

# Enable auto-saving
data = CheckpointDataset(
data,
cfg.ckpt_load_path if cfg.resuming_dataset else cfg.ckpt_save_path,
cfg.checkpoint_interval,
cfg.batch_size,
cfg.ckpt_save_path,
)
# data = CheckpointDataset(
# data,
# cfg.ckpt_load_path if cfg.resuming_dataset else cfg.ckpt_save_path,
# cfg.checkpoint_interval,
# cfg.batch_size,
# cfg.ckpt_save_path,
# )
return torch.utils.data.DataLoader(
data, num_workers=cfg.num_workers, batch_size=cfg.batch_size
)
Expand Down
122 changes: 70 additions & 52 deletions fms_fsdp/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,10 @@ def length(self, path: str):

def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
    """Return document `index` from an open arrow reader, with boundary
    drop-tokens (e.g. bos/eos) stripped.

    Length is rechecked before the trailing strip so that a document reduced
    to a single token by the leading strip is handled safely.
    """
    tokens = reader.get_batch(index)[self.col_name]
    if len(tokens) > 0 and tokens[0].as_py() in drop_tokens:
        tokens = tokens.slice(1, len(tokens) - 1)
    if len(tokens) > 0 and tokens[-1].as_py() in drop_tokens:
        tokens = tokens.slice(0, len(tokens) - 1)
    return tokens

def slice(self, doc: pa.UInt32Array, index: int, n_pull: int) -> List:
def is_legal(self, filepath: str):
    """A file is a legal shard iff "parquet" appears in its extension."""
    _, extension = os.path.splitext(filepath)
    return "parquet" in extension

def open(self, path: str):
    """Open a parquet shard and return its single recognized text column.

    The column is discovered by intersecting the file schema (read from
    metadata only) with the known text-field names; exactly one match
    must exist.
    """
    schema_names = pq.read_metadata(path).schema.names
    candidates = ["text", "content", "contents"]
    overlap = set(candidates).intersection(set(schema_names))
    assert (
        len(overlap) == 1
    ), f"{len(overlap)} shared column names detected, need 1 ({overlap})"
    colname = overlap.pop()
    return pq.read_pandas(path, columns=[colname], partitioning=None)[colname]

def length(self, path: str):
    """Document (row) count of the parquet file, taken from metadata only."""
    metadata = pq.read_metadata(path)
    return metadata.num_rows

def get(self, reader, index: int, drop_tokens: Set):
    """Tokenize row `index` of the opened column, stripping boundary drop-tokens.

    The raw text is truncated to 128k characters before tokenization to bound
    tokenizer cost on pathological documents. Length is rechecked before the
    trailing strip to cover the edge case where the document is a single
    drop-token (e.g. doc == [eos]).
    """
    ids = self.tokenizer(str(reader[index])[:128_000])["input_ids"]
    if len(ids) > 0 and ids[0] in drop_tokens:
        ids = ids[1:]
    if len(ids) > 0 and ids[-1] in drop_tokens:
        ids = ids[:-1]
    return ids

def slice(self, doc: List, index: int, n_pull: int) -> List:
Expand Down Expand Up @@ -872,73 +877,86 @@ def setup(self):
if self.filehandler.is_legal(os.path.join(root, name))
]
shards.sort() # Ensure consistent sharding across machines
start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
end_frag = (
(self.rank + 1) * self.worldsize * len(shards)
) // self.worldsize
shardfrags = [
(shards[i // self.worldsize], i % self.worldsize)
for i in range(start_frag, end_frag)
]

# Assemble length of each owned shard file

# Find metadata file
countfiles = []
if os.path.exists(os.path.join(pardir, "meta")):
countfiles = [
x
for x in os.listdir(os.path.join(pardir, "meta"))
if "counts" in x and "csv" in x
]
doc_counts = {}
if len(countfiles) > 0:
# Count file exists, use it
countpath = os.path.join(pardir, "meta", countfiles[0])
else:
countpath = ""

# Use shard file sizes to perform partitioning
# Create shardlist of form shardid -> [start%, end%]
if len(countfiles) > 0:
sizes = {}
with open(countpath, "r") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
fullpath = row["dataset/filename"]
prefix = fullpath.find(dataset + "/")
if prefix >= 0:
key = fullpath[prefix + len(dataset) + 1 :]
sizes[key] = int(row["size"])
shard_sizes = [sizes[shard] for shard in shards]
else:
shard_sizes = [
os.path.getsize(os.path.join(datapath, shard)) for shard in shards
]
shard_sizes = [s / sum(shard_sizes) for s in shard_sizes]
start = self.rank / self.worldsize
end = (self.rank + 1) / self.worldsize
shardset = {}
tally = 0
for i in range(len(shards)):
if tally <= end and tally + shard_sizes[i] >= start:
shardset[shards[i]] = [
min(max((start - tally) / shard_sizes[i], 0), 1),
min(max((end - tally) / shard_sizes[i], 0), 1),
]
tally += shard_sizes[i]

# Assemble length of each owned shard file
doc_counts = {}
if len(countfiles) > 0:
# Count file exists, use it
with open(countpath, "r") as csvfile:
reader = csv.DictReader(csvfile)
for row in reader:
fullpath = row["dataset/filename"]
prefix = fullpath.find("/" + dataset) + 1
if prefix > 0:
prefix = fullpath.find(dataset + "/")
if prefix >= 0:
key = fullpath[prefix + len(dataset) + 1 :]
doc_counts[key] = int(row["documents"])
else:
# Count file does not exist, touch every owned file for length
unique_shardfiles = set(shard for shard, frag in shardfrags)
# unique_shardfiles = set(shard for shard, frag in shardfrags)
doc_counts = {
shard: self.filehandler.length(os.path.join(datapath, shard))
for shard in unique_shardfiles
for shard in shardset
}

# Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
ndocs = -1
docset = {} # shardid -> (min docid, max docid)
for i, (shard, frag) in enumerate(shardfrags):
ndocs = doc_counts[shard]
doc_start = (ndocs * frag) // self.worldsize
doc_end = (
ndocs * frag + ndocs
) // self.worldsize - 1 # Inclusive upper bound
if shard not in docset:
docset[shard] = [doc_start, doc_end]
min_d, max_d = docset[shard]
if doc_start < min_d:
docset[shard][0] = doc_start
if doc_end > max_d:
docset[shard][1] = doc_end

# Add shard entries to self.docset
# Assemble doc list for each file shard
# Create docset of form [shardid, min docid, max docid]
doccount = 0
for shardid in docset:
min_d = docset[shardid][0]
max_d = docset[shardid][1]
self.docset.append((shardid, min_d, max_d))
doccount += max_d - min_d + 1
for shard in shardset:
ndocs = doc_counts[shard]
doc_start = round(ndocs * shardset[shard][0])
doc_end = round(ndocs * shardset[shard][1]) - 1 # inclusive upper bound
if doc_end >= doc_start:
self.docset.append([shard, doc_start, doc_end])
doccount += doc_end - doc_start + 1
self._len = doccount

if self.verbose:
logging.info(
f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
f" Worker {self.rank} ingested {len(self.docset)} shards from {dataset}"
)

# Shuffle shard files - guaranteed inconsistent across workers
Expand Down
Loading