Skip to content

Commit 44ed516

Browse files
Authored by: FromCSUZhou, pre-commit-ci[bot], nickcom007
authored
feat: added byte and target token calculations for evaluation datasets, improved calculation logic for BPC/bPPL (#82)
feat: added byte and target token calculations for evaluation datasets, improved calculation logic for BPC/bPPL (#82)

* feat: added byte and target token calculations for evaluation datasets, improved calculation logic for BPC/bPPL
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix: removed unused numpy imports to optimize code
* feat: optimize the calculation logic code and logging function of BPC and bPPL
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* feat: simplify the relevant calculation logic and log code in the verification process
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix: add handling for loss calculation

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nick <nickcom007@gmail.com>
1 parent 2d878d1 commit 44ed516

File tree

6 files changed

+731
-10
lines changed

6 files changed

+731
-10
lines changed

src/core/log_utils.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import math
2+
from loguru import logger
3+
import numbers
4+
5+
6+
def _log_summary_table(
    model_name_or_path,
    eval_loss,
    bpc_metrics,
    token_byte_ratio,
    total_target_tokens,
    total_bytes,
    vocab_size,
    model_params_m,
):
    """Log a vertical key/value summary table of validation metrics.

    Args:
        model_name_or_path: Identifier of the evaluated model.
        eval_loss: Average token-level loss in nats (may be NaN/inf or a
            non-numeric sentinel).
        bpc_metrics: Dict with at least ``'bpc'`` and ``'bppl'`` entries.
        token_byte_ratio: Ratio of target tokens to target bytes.
        total_target_tokens: Total number of target tokens.
        total_bytes: Total number of target UTF-8 bytes.
        vocab_size: Tokenizer vocabulary size.
        model_params_m: Total model parameters in millions, or NaN when the
            model is unavailable.
    """

    def _fmt(value, spec):
        # One shared guard for every numeric cell: format real, finite
        # numbers with `spec`; fall back to str() for NaN/inf or
        # non-numeric sentinels.  (The original used four inconsistent
        # guards, and the BPC/bPPL/ratio ones crashed on non-numeric
        # values because math.isinf was called without an isinstance
        # check first.)
        if isinstance(value, numbers.Real) and math.isfinite(value):
            return format(value, spec)
        return str(value)

    table_data = {
        "Model Name": model_name_or_path,
        "Token Loss (nats)": _fmt(eval_loss, ".5f"),
        "BPC": _fmt(bpc_metrics["bpc"], ".5f"),
        "bPPL": _fmt(bpc_metrics["bppl"], ".5f"),
        "T/B Ratio": _fmt(token_byte_ratio, ".4f"),
        "Target Tokens": str(total_target_tokens),
        "Target Bytes": str(total_bytes),
        "Vocab Size": str(vocab_size),
        "Total Params (M)": _fmt(model_params_m, ".2f"),
    }

    label_width = max(len(label) for label in table_data.keys())
    value_width = max(len(str(value)) for value in table_data.values())
    total_width = label_width + value_width + 3

    title = " Validation Summary "
    pad = max(total_width - len(title), 0)
    # Pad on both sides so the header is exactly `total_width` wide even
    # when the leftover width is odd (the original integer-halved both
    # sides and dropped a character).
    header = "=" * (pad // 2) + title + "=" * (pad - pad // 2)

    lines = [header]
    for label, value in table_data.items():
        if label == "Model Name" and len(value) > value_width:
            # Truncate long model names so the column stays aligned.
            value = value[: value_width - 3] + "..."
        lines.append(f"{label:<{label_width}} | {value:<{value_width}}")
    lines.append("=" * total_width)

    # Emit the whole table through the logger in a single call.  The
    # original printed the rows and footer to stdout while logging only
    # the header, so the table interleaved out of order with the log sink.
    logger.info("\n" + "\n".join(lines) + "\n")

src/core/loss.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import math
2+
import numbers
3+
4+
5+
def calculate_bpc_bppl_metrics(eval_loss, total_target_tokens, total_bytes):
    """
    Compute BPC (Bits Per Character) and bPPL (bits Per Character Perplexity).

    Args:
        eval_loss (float): Average token-level loss in nats.
        total_target_tokens (int): Total number of target tokens.
        total_bytes (int): Total number of target bytes.

    Returns:
        dict: Keys 'bpc', 'bppl', 'nll_token_nats_total',
        'nll_token_bits_total'.  When total_bytes is 0 or eval_loss is
        unusable (non-real, NaN, or infinite), 'bpc' and 'bppl' are
        float('inf') and the totals are float('nan').  'bppl' is also
        float('inf') when bpc itself is infinite or math.pow(2, bpc)
        overflows for a large finite bpc.
    """
    loss_is_unusable = (
        not isinstance(eval_loss, numbers.Real)
        or math.isnan(eval_loss)
        or math.isinf(eval_loss)
    )
    if total_bytes == 0 or loss_is_unusable:
        # No meaningful metrics can be produced; signal with sentinels.
        return {
            "bpc": float("inf"),
            "bppl": float("inf"),
            "nll_token_nats_total": float("nan"),
            "nll_token_bits_total": float("nan"),
        }

    nats_total = eval_loss * total_target_tokens
    bits_total = nats_total / math.log(2)
    bpc = bits_total / total_bytes

    bppl = float("inf")
    if not math.isinf(bpc):
        try:
            bppl = math.pow(2, bpc)
        except OverflowError:
            pass  # 2**bpc exceeds the float range; keep inf.

    return {
        "bpc": bpc,
        "bppl": bppl,
        "nll_token_nats_total": nats_total,
        "nll_token_bits_total": bits_total,
    }
53+
54+
55+
def get_token_byte_ratio(total_target_tokens, total_bytes):
    """
    Compute the token-to-byte ratio of the evaluation targets.

    Args:
        total_target_tokens (int): Total number of target tokens.
        total_bytes (int): Total number of target bytes.

    Returns:
        float: total_target_tokens / total_bytes, or float('inf') when
        total_bytes is 0.
    """
    # Guard the zero-byte case explicitly rather than letting division raise.
    return float("inf") if total_bytes == 0 else total_target_tokens / total_bytes
69+
70+
71+
def calculate_bytes_and_tokens(eval_dataset, tokenizer, logger):
    """
    Compute total UTF-8 bytes and target-token count for an eval dataset.

    Args:
        eval_dataset: Indexable dataset; each item is a mapping with
            parallel "input_ids" and "target_mask" sequences, where
            mask == 1 marks a target position.
        tokenizer: Tokenizer providing ``decode(ids, skip_special_tokens=...)``.
        logger: Logger instance used for the progress message.

    Returns:
        tuple: (total_bytes, total_target_tokens).
    """
    total_bytes = 0
    total_target_tokens = 0
    logger.info(
        "Calculating total bytes and target tokens in the evaluation dataset..."
    )
    # Index-based access (not plain iteration) because the dataset is only
    # assumed to support __len__ and __getitem__ here.
    for idx in range(len(eval_dataset)):
        item = eval_dataset[idx]
        # Keep only the token ids at masked (target) positions.  The loop
        # variable is named token_id: the original used `id`, shadowing the
        # builtin of the same name.
        target_ids = [
            token_id
            for token_id, mask in zip(item["input_ids"], item["target_mask"])
            if mask == 1
        ]
        if target_ids:
            # Byte length is measured on the decoded text with special
            # tokens stripped, so specials don't inflate the byte count.
            target_text = tokenizer.decode(target_ids, skip_special_tokens=True)
            total_bytes += len(target_text.encode("utf-8"))
            total_target_tokens += len(target_ids)
    return total_bytes, total_target_tokens

src/validate.py

Lines changed: 88 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import numbers
23
import os
34
import time
45
import shutil
@@ -31,12 +32,19 @@
3132
handle_runtime_error,
3233
handle_value_error,
3334
)
35+
from core.loss import (
36+
calculate_bpc_bppl_metrics,
37+
get_token_byte_ratio,
38+
calculate_bytes_and_tokens,
39+
)
40+
from core.log_utils import _log_summary_table
3441
from tenacity import retry, stop_after_attempt, wait_exponential
3542
from client.fed_ledger import FedLedger
3643
from peft import PeftModel
3744
import sys
3845
import math
3946

47+
4048
load_dotenv()
4149
TIME_SLEEP = int(os.getenv("TIME_SLEEP", 60 * 3))
4250
ASSIGNMENT_LOOKUP_INTERVAL = 60 * 3 # 3 minutes
@@ -292,6 +300,14 @@ def validate(
292300

293301
model = None
294302
eval_dataset = None
303+
bpc_metrics_results = {
304+
"bpc": float("inf"),
305+
"bppl": float("inf"),
306+
"nll_token_nats_total": float("nan"),
307+
"nll_token_bits_total": float("nan"),
308+
}
309+
token_byte_ratio_value = float("inf")
310+
eval_loss = float("nan") # Initialize eval_loss
295311

296312
try:
297313
fed_ledger = FedLedger(FLOCK_API_KEY)
@@ -379,6 +395,28 @@ def validate(
379395
eval_dataset = load_sft_dataset(
380396
eval_file, context_length, template_name=base_model, tokenizer=tokenizer
381397
)
398+
399+
total_bytes, total_target_tokens = calculate_bytes_and_tokens(
400+
eval_dataset, tokenizer, logger
401+
)
402+
403+
if total_bytes == 0:
404+
logger.warning(
405+
"Total bytes in the evaluation dataset is 0. Cannot calculate BPC. Check dataset processing."
406+
)
407+
eval_loss_to_submit = LOSS_FOR_MODEL_PARAMS_EXCEED
408+
else:
409+
logger.info(f"Total target bytes (B): {total_bytes}")
410+
logger.info(f"Total target tokens (T): {total_target_tokens}")
411+
token_byte_ratio_value = get_token_byte_ratio(
412+
total_target_tokens, total_bytes
413+
)
414+
logger.info(f"Token/Byte ratio (T/B): {token_byte_ratio_value:.4f}")
415+
if token_byte_ratio_value < 0.1:
416+
logger.warning(
417+
f"Token/Byte ratio ({token_byte_ratio_value:.4f}) is unusually low. Potential manipulation detected."
418+
)
419+
382420
model = load_model(
383421
model_name_or_path, lora_only, revision, val_args, cached_lora
384422
)
@@ -413,19 +451,60 @@ def validate(
413451
data_collator=data_collator,
414452
)
415453

454+
logger.info("Starting evaluation...")
416455
eval_result = trainer.evaluate()
417456
eval_loss = eval_result["eval_loss"]
418-
logger.info("evaluate result is %s" % str(eval_result))
457+
458+
logger.info("Raw evaluation result: %s" % str(eval_result))
459+
460+
if total_bytes > 0:
461+
bpc_metrics_results = calculate_bpc_bppl_metrics(
462+
eval_loss, total_target_tokens, total_bytes
463+
)
464+
465+
is_bpc_valid = not math.isinf(bpc_metrics_results["bpc"])
466+
467+
_log_summary_table(
468+
model_name_or_path=model_name_or_path,
469+
eval_loss=eval_loss,
470+
bpc_metrics=bpc_metrics_results,
471+
token_byte_ratio=token_byte_ratio_value,
472+
total_target_tokens=total_target_tokens,
473+
total_bytes=total_bytes,
474+
vocab_size=tokenizer.vocab_size,
475+
model_params_m=(sum(p.numel() for p in model.parameters()) / 1e6)
476+
if model
477+
else float("nan"),
478+
)
479+
419480
if local_test:
420-
logger.info("The model can be correctly validated by validators.")
481+
logger.info(
482+
"The model can be correctly validated by validators (raw loss)."
483+
)
484+
if not is_bpc_valid: # If BPC is inf
485+
logger.warning(
486+
"Could not calculate BPC/bPPL for local test due to zero bytes or invalid loss."
487+
)
421488
return
422-
# sometimes the loss might not be a valid float
423-
if isinstance(eval_loss, float) and (
424-
math.isnan(eval_loss) or math.isinf(eval_loss)
425-
):
426-
eval_loss = LOSS_FOR_MODEL_PARAMS_EXCEED
489+
490+
eval_loss_to_submit = LOSS_FOR_MODEL_PARAMS_EXCEED # Default to high loss
491+
492+
if is_bpc_valid:
493+
eval_loss_to_submit = bpc_metrics_results["bpc"]
494+
else:
495+
if total_bytes == 0:
496+
logger.error("Total bytes is 0, submitting high loss.")
497+
elif (
498+
not isinstance(eval_loss, numbers.Real)
499+
or math.isnan(eval_loss)
500+
or math.isinf(eval_loss)
501+
):
502+
logger.error(f"Invalid eval_loss ({eval_loss}), submitting high loss.")
503+
427504
resp = fed_ledger.submit_validation_result(
428-
assignment_id=assignment_id, loss=eval_loss, gpu_type=gpu_type
505+
assignment_id=assignment_id,
506+
loss=eval_loss_to_submit, # Submit BPC as loss
507+
gpu_type=gpu_type,
429508
)
430509
# check response is 200
431510
if resp.status_code != 200:
@@ -439,10 +518,9 @@ def validate(
439518
fed_ledger.mark_assignment_as_failed(assignment_id)
440519
return
441520
logger.info(
442-
f"Successfully submitted validation result for assignment {assignment_id}"
521+
f"Successfully submitted validation result (BPC: {eval_loss_to_submit}) for assignment {assignment_id}"
443522
)
444523

445-
# raise for exceptions, will handle at `loop` level
446524
except Exception as e:
447525
raise e
448526
finally:

tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Tests package for llm-loss-validator

tests/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Core module tests for llm-loss-validator

0 commit comments

Comments
 (0)