Skip to content

Commit 2427205

Browse files
Implement Laplace smoothing in dialog log-likelihood computation
1 parent 2a49d7d commit 2427205

File tree

1 file changed

+35
-4
lines changed

1 file changed

+35
-4
lines changed

src/sdialog/evaluation/base.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,15 @@ def __init__(self,
283283
if node_id[0].lower() == "u"],
284284
k=k_neighbors)
285285

286+
# Pre-compute vocabulary size for Laplace smoothing (exclude metadata nodes)
287+
self.vocab_size = len([node_id for node_id in self.nodes.keys() if not node_id.startswith("_")])
288+
289+
# Pre-compute total outbound frequency counts for each node (for Laplace smoothing)
290+
self.outbound_totals = {
291+
node: sum(d["fr"] for _, _, d in self.graph.out_edges(node, data=True))
292+
for node in self.graph.nodes()
293+
}
294+
286295
def get_node_sequence(self, dialog: Dialog, probs: bool = False) -> List[str]:
287296
"""
288297
Map each turn to its nearest node and optionally return transition probabilities.
@@ -324,12 +333,23 @@ def get_node_sequence(self, dialog: Dialog, probs: bool = False) -> List[str]:
324333

325334
def compute_dialog_log_likelihood(self, dialog: Dialog) -> Tuple[float, int]:
326335
"""
327-
Compute cumulative log-probability statistics for a dialog.
336+
Compute cumulative log-probability statistics for a dialog using Laplace smoothing.
337+
338+
Laplace smoothing approach (add-one smoothing):
339+
- For each transition: P_laplace(dest|src) = (count(src→dest) + 1) / (sum_outbound_counts + V)
340+
- Known edges: Use edge frequency count + 1
341+
- Unknown edges: The observed count is 0, so the numerator is just the pseudo-count of 1
342+
- Both are normalized by (total_outbound_count + V), where V is the vocabulary size (the number of nodes whose ID does not start with "_")
343+
344+
This provides a principled probability distribution that:
345+
1. Smooths all transitions (known and unknown) consistently
346+
2. Avoids zero probabilities for unseen transitions
347+
3. Doesn't modify the graph structure (scoring-time smoothing)
328348
329349
Returns four values:
330350
sum_log_p_known: Sum of log probabilities only over known edges.
331351
n_turns_known: Count of contributing turns with known edges (includes initial offset).
332-
sum_log_p: Sum over all considered turns (unknown edges use uniform fallback).
352+
sum_log_p: Sum over all considered turns (with Laplace smoothing).
333353
n_turns: Total counted turns (includes initial offset; respects ai_speaker filtering).
334354
335355
:param dialog: Dialog to evaluate.
@@ -341,8 +361,10 @@ def compute_dialog_log_likelihood(self, dialog: Dialog) -> Tuple[float, int]:
341361
sum_log_p, sum_log_p_known = 0, 0
342362
n_turns, n_turns_known = 1, 1 # start with 1 to account for the first turn and avoid division by zero
343363
prev_node = DEFAULT_TOKEN_START
364+
344365
if self.only_system:
345366
dialog = [turn for turn in dialog if turn.speaker.lower() == self.ai_speaker.lower()]
367+
346368
for turn in dialog:
347369
speaker = turn.speaker.lower()
348370
if speaker in self.speakers:
@@ -355,14 +377,23 @@ def compute_dialog_log_likelihood(self, dialog: Dialog) -> Tuple[float, int]:
355377
current_node, _ = neighbors[0]
356378
prob_correct_node = softmax([1 - dist for _, dist in neighbors])[0] if self.use_softmax else 1
357379

380+
# Get total outbound count for source node (for Laplace smoothing denominator)
381+
total_outbound_count = self.outbound_totals.get(prev_node, 0)
382+
358383
prob_current_node = self.graph.get_edge_data(prev_node, current_node)
359384
if prob_current_node is not None:
360-
log_p = log(prob_current_node["weight"] * prob_correct_node)
385+
# Known edge: Laplace smoothing = (count + 1) / (total + V)
386+
edge_count = prob_current_node["fr"]
387+
smoothed_prob = (edge_count + 1) / (total_outbound_count + self.vocab_size)
388+
log_p = log(smoothed_prob * prob_correct_node)
361389
sum_log_p += log_p
362390
sum_log_p_known += log_p
363391
n_turns_known += 1
364392
else:
365-
sum_log_p += log(1e-12) # fallback for unknown edges
393+
# Unknown edge: Laplace smoothing = 1 / (total + V)
394+
smoothed_prob = 1 / (total_outbound_count + self.vocab_size)
395+
sum_log_p += log(smoothed_prob)
396+
366397
n_turns += 1
367398
prev_node = current_node
368399

0 commit comments

Comments
 (0)