Commit 0067eac

Merge branch 'main' into toloka-armenia
Signed-off-by: Alexan <[email protected]>
2 parents: 859c31f + 371c2b5

File tree: 8 files changed, +449 −22 lines changed


docs/gen_docs.py

Lines changed: 0 additions & 1 deletion

@@ -53,6 +53,5 @@ def gen_docs():
     with open(destination_path, "wt", encoding="utf-8") as fout:
         fout.write(docs + link)
 
-
 if __name__ == '__main__':
     gen_docs()

docs/src/sdp/api.rst

Lines changed: 3 additions & 0 deletions

@@ -212,6 +212,9 @@ ASR-based processors
 ``text_key`` (defaults to "text") and ``pred_text_key`` (defaults to "text_pred")
 to control which fields contain transcription and ASR model predictions.
 
+.. autodata:: sdp.utils.BootstrapProcessor
+   :annotation:
+
 Data modifications
 ''''''''''''''''''
 

docs/src/sdp/existing_configs.rst

Lines changed: 1 addition & 1 deletion

@@ -351,4 +351,4 @@ Tarteel AI's EveryAyah
 .. toctree::
    :hidden:
 
-   config-docs/arabic/everyayah/config
+   config-docs/arabic/everyayah/config

(The deleted and added lines render identically here; the underlying change is presumably whitespace-only.)

sdp/processors/base_processor.py

Lines changed: 65 additions & 17 deletions

@@ -82,12 +82,12 @@ def test(self):
         There are no tests by default.
         """
 
-class DaskParallelProcessor(BaseProcessor):
+class BaseParallelProcessor(BaseProcessor):
     """
     Processor class which allows operations on each entry to be parallelized using Dask.
 
     Parallelization is done by distributing the workload using Dask bags inside
-    the :meth:`process` method.
+    the :meth:`process` method.
 
     Actual processing should be defined on a per-example basis inside the
     :meth:`process_dataset_entry` method.
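
The docstring above pins down the contract: concrete processors override only process_dataset_entry(). A minimal sketch of what that looks like (the subclass name, field name, and import path are illustrative assumptions, not part of this commit):

    # Hypothetical BaseParallelProcessor subclass, for illustration only.
    from sdp.processors.base_processor import BaseParallelProcessor, DataEntry

    class LowercaseText(BaseParallelProcessor):
        """Lowercase the "text" field of every manifest entry."""

        def process_dataset_entry(self, data_entry: dict):
            data_entry["text"] = data_entry["text"].lower()
            # Each input entry maps to a list of DataEntry objects, so a
            # processor can also drop entries (return []) or emit several.
            return [DataEntry(data=data_entry)]
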
@@ -127,16 +127,40 @@ def process(self):
         import multiprocessing
         import psutil
 
+        os.environ.setdefault("PATH", os.defpath)
+        default_path = os.environ["PATH"]
 
         self.prepare()
         os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True)
         metrics = []
 
-        # total line count.
-        with open(self.input_manifest_file, "rb") as f:
-            total_entries = sum(1 for _ in f)
-
-        # check that we did not set different amount of cpu.
+        # If no input manifest file is provided, use the read_manifest() method to get entries.
+        if self.input_manifest_file is None:
+            manifest_entries = list(self.read_manifest())
+            total_entries = len(manifest_entries)
+            if total_entries == 0:
+                logger.info("No input manifest entries found; using empty bag.")
+                bag = db.from_sequence([])
+            else:
+                bag = db.from_sequence(manifest_entries)
+        else:
+            # Check if input manifest file exists and is non-empty.
+            if (not os.path.exists(self.input_manifest_file) or os.path.getsize(self.input_manifest_file) == 0):
+                logger.info("Input manifest file not found or empty; using empty bag.")
+                total_entries = 0
+                bag = db.from_sequence([])
+            else:
+                # Compute total line count.
+                with open(self.input_manifest_file, "rb") as f:
+                    total_entries = sum(1 for _ in f)
+                if total_entries == 0:
+                    logger.info("Input manifest file is empty; using empty bag.")
+                    bag = db.from_sequence([])
+                else:
+                    bag = db.read_text(self.input_manifest_file, encoding="utf-8", blocksize="8MB")
+                    bag = bag.map(json.loads)
+
+        # Set up worker resources.
         num_cpus = multiprocessing.cpu_count() if self.max_workers == -1 else self.max_workers
         total_memory = psutil.virtual_memory().total
         mem_per_worker = total_memory // num_cpus
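
Every branch above leaves both total_entries and bag bound, so the rest of process() never has to special-case a missing or empty manifest. A condensed, standalone sketch of the same fallback (the path is a placeholder):

    # Bag construction with an empty-input fallback, mirroring the diff above.
    import json
    import os

    import dask.bag as db

    path = "manifest.json"  # placeholder
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        bag = db.from_sequence([])  # empty bag: downstream code still runs
    else:
        # One JSON object per line, parsed lazily in ~8MB partitions.
        bag = db.read_text(path, encoding="utf-8", blocksize="8MB").map(json.loads)
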
@@ -154,11 +178,14 @@ def process(self):
             "distributed.comm.timeouts.connect": "30s",
             "distributed.comm.timeouts.tcp": "30s"
         }):
-            client = Client(n_workers=num_cpus, processes=True, threads_per_worker=2, memory_limit=memory_limit)
+            client = Client(
+                n_workers=num_cpus,
+                processes=True,
+                threads_per_worker=2,
+                memory_limit=memory_limit,
+                env={"PATH": default_path}
+            )
             try:
-                bag = db.read_text(self.input_manifest_file, encoding="utf-8", blocksize="8MB")
-                bag = bag.map(json.loads)
-
                 def process_partition(partition):
                     results = []
                     for data_entry in partition:
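
The env={"PATH": default_path} argument pairs with the os.environ.setdefault("PATH", os.defpath) call added earlier: with processes=True, workers run in separate processes that may not inherit a usable PATH, so the parent captures one and forwards it explicitly. The guard on its own:

    # PATH guard in isolation; os.defpath is the OS fallback search path
    # (":/bin:/usr/bin" on most POSIX systems).
    import os

    os.environ.setdefault("PATH", os.defpath)
    default_path = os.environ["PATH"]  # now guaranteed to exist
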
@@ -175,13 +202,11 @@ def process_partition(partition):
                 delayed_results = bag.to_delayed()
                 futures = client.compute(delayed_results)
 
-                # Open output file before starting to process futures.
                 with open(self.output_manifest_file, "wt", encoding="utf8") as fout, \
                      tqdm(total=total_entries, desc="Processing entries", unit="entry",
                           mininterval=0.5, miniters=10) as pbar:
                     for future in as_completed(futures):
                         partition_result = future.result()
-                        # Write results immediately (as soon as available)
                         for data_entry in partition_result:
                             metrics.append(data_entry.metrics)
                             if data_entry.data is None:
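
Note the streaming shape of this loop: rather than materializing the whole bag with bag.compute(), each partition is written out as soon as its future resolves, so peak memory stays near one partition regardless of manifest size. A minimal sketch of the pattern (reusing bag and json from the earlier sketch; the output path is a placeholder):

    # Write partitions as they complete instead of collecting all results.
    from dask.distributed import Client, as_completed

    client = Client(n_workers=2, processes=True)
    futures = client.compute(bag.to_delayed())  # one future per partition
    with open("out.jsonl", "wt", encoding="utf8") as fout:
        for future in as_completed(futures):    # completion order, not submit order
            for item in future.result():
                fout.write(json.dumps(item) + "\n")
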
@@ -218,10 +243,14 @@ def _chunk_manifest(self):
             yield manifest_chunk
 
     def read_manifest(self):
-        """Read entries from the input manifest file."""
-        if not self.input_manifest_file:
-            raise ValueError("Input manifest file is not specified.")
+        """Read entries from the input manifest file.
 
+        Returns an iterator over JSON objects. If no input manifest file is set,
+        returns an empty iterator.
+        """
+        if not self.input_manifest_file:
+            logger.info("No input manifest file specified; returning empty manifest iterator.")
+            return iter([])
         with open(self.input_manifest_file, "r", encoding="utf-8") as fin:
             for line in fin:
                 yield json.loads(line)
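
One subtlety: because read_manifest() contains a yield, it is a generator function, so callers always get a generator back and the value in return iter([]) is discarded (it only surfaces as StopIteration.value). A bare return would behave identically; either way the caller sees an empty iterator:

    # Behavior sketch: an early return in a generator yields nothing.
    def read_manifest(path=None):
        if not path:
            return iter([])  # value is dropped; equivalent to a bare return
        yield from ()        # stand-in for the real JSON-lines body

    assert list(read_manifest()) == []
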
@@ -249,9 +278,28 @@ def finalize(self, metrics: List[Any]):
         )
         logger.info("Processor completed in (seconds): %.2f", time.time() - self.start_time)
 
+    def test(self):
+        """Applies processing to each test case and raises an error if the output does not match expected output."""
+        for test_case in self.test_cases:
+            input_data = test_case["input"].copy() if isinstance(test_case["input"], dict) else test_case["input"]
+            generated_outputs = self.process_dataset_entry(input_data)
+            expected_outputs = (
+                [test_case["output"]] if not isinstance(test_case["output"], list)
+                else test_case["output"]
+            )
+            for generated_output, expected_output in zip(generated_outputs, expected_outputs):
+                generated_data = generated_output.data if hasattr(generated_output, "data") else generated_output
+                if generated_data != expected_output:
+                    raise RuntimeError(
+                        "Runtime test failed.\n"
+                        f"Test input: {test_case['input']}\n"
+                        f"Generated output: {generated_data}\n"
+                        f"Expected output: {expected_output}"
+                    )
 
 
-class BaseParallelProcessor(BaseProcessor):
+
+class LegacyParallelProcessor(BaseProcessor):
     """Processor class which allows operations on each utterance to be parallelized.
 
     Parallelization is done using ``tqdm.contrib.concurrent.process_map`` inside
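
The new test() method implies a test_cases layout like the one below; the field names are read directly off the code above, while the concrete values are invented. Also worth noting: zip() truncates silently, so if a processor generates fewer outputs than expected, the surplus expected outputs go unchecked.

    # Inferred shape of self.test_cases, based only on how test() consumes it.
    test_cases = [
        # one expected output per input
        {"input": {"text": "Hello, WORLD"}, "output": {"text": "hello, world"}},
        # a processor may emit several entries; "output" is then a list
        {"input": {"text": "a b"}, "output": [{"text": "a"}, {"text": "b"}]},
    ]
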

sdp/processors/modify_manifest/common.py

Lines changed: 3 additions & 3 deletions

@@ -24,7 +24,7 @@
     BaseParallelProcessor,
     BaseProcessor,
     DataEntry,
-    DaskParallelProcessor,
+    LegacyParallelProcessor,
 )
 from sdp.utils.common import load_manifest
 

@@ -99,9 +99,9 @@ def process_dataset_entry(self, data_entry: Dict):
         return [DataEntry(data=data_entry)]
 
 
-class AddConstantFields(DaskParallelProcessor):
+class AddConstantFields(BaseParallelProcessor):
     """
-    This processor adds constant fields to all manifest entries using DaskParallelProcessor.
+    This processor adds constant fields to all manifest entries using the Dask-based BaseParallelProcessor.
     It is useful when you want to attach fixed information (e.g., a language label or metadata)
     to each entry for downstream tasks such as language identification model training.

sdp/utils/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -11,3 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+
+from sdp.utils.bootstrap_estimates import BootstrapProcessor
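
This re-export is what lets the .. autodata:: sdp.utils.BootstrapProcessor directive added to api.rst above resolve. Once it is in place, both import paths name the same class:

    # Both paths resolve to the same object after the re-export.
    import sdp.utils
    import sdp.utils.bootstrap_estimates

    assert sdp.utils.BootstrapProcessor is sdp.utils.bootstrap_estimates.BootstrapProcessor
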
