135 changes: 135 additions & 0 deletions graphgen/bases/base_splitter.py
@@ -0,0 +1,135 @@
import copy
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Literal, Optional, Union

from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger


@dataclass
class BaseSplitter(ABC):
"""
Abstract base class for splitting text into smaller chunks.
"""

chunk_size: int = 1024
chunk_overlap: int = 100
length_function: Callable[[str], int] = len
keep_separator: bool = False
add_start_index: bool = False
strip_whitespace: bool = True

@abstractmethod
def split_text(self, text: str) -> List[str]:
"""
Split the input text into smaller chunks.

:param text: The input text to be split.
:return: A list of text chunks.
"""

def create_chunks(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[Chunk]:
"""Create chunks from a list of texts."""
_metadatas = metadatas or [{}] * len(texts)
chunks = []
for i, text in enumerate(texts):
index = 0
previous_chunk_len = 0
for chunk in self.split_text(text):
metadata = copy.deepcopy(_metadatas[i])
if self.add_start_index:
offset = index + previous_chunk_len - self.chunk_overlap
index = text.find(chunk, max(0, offset))
metadata["start_index"] = index
previous_chunk_len = len(chunk)
new_chunk = Chunk(content=chunk, metadata=metadata)
Copilot AI commented on this line (Sep 24, 2025):

Creating Chunk instances without providing the required id field will cause runtime errors. The Chunk dataclass requires all three fields (id, content, metadata), but only content and metadata are being provided.
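One way to satisfy the dataclass would be to derive the required id from the chunk content. A minimal sketch of such a fix, assuming an id based on a content hash (the _chunk_id helper is illustrative and not an existing project API):

    import hashlib

    def _chunk_id(text: str) -> str:
        # Illustrative: stable identifier derived from the chunk text.
        return "chunk-" + hashlib.md5(text.encode("utf-8")).hexdigest()

    new_chunk = Chunk(id=_chunk_id(chunk), content=chunk, metadata=metadata)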
chunks.append(new_chunk)
return chunks

def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
text = separator.join(chunks)
if self.strip_whitespace:
text = text.strip()
if text == "":
return None
return text

def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
# We now want to combine these smaller pieces into medium size chunks to send to the LLM.
separator_len = self.length_function(separator)

chunks = []
current_chunk: List[str] = []
total = 0
for d in splits:
_len = self.length_function(d)
if (
total + _len + (separator_len if len(current_chunk) > 0 else 0)
> self.chunk_size
):
if total > self.chunk_size:
logger.warning(
"Created a chunk of size %s, which is longer than the specified %s",
total,
self.chunk_size,
)
if len(current_chunk) > 0:
chunk = self._join_chunks(current_chunk, separator)
if chunk is not None:
chunks.append(chunk)
# Keep on popping if:
# - we have a larger chunk than in the chunk overlap
# - or if we still have any chunks and the length is long
while total > self.chunk_overlap or (
total + _len + (separator_len if len(current_chunk) > 0 else 0)
> self.chunk_size
and total > 0
):
total -= self.length_function(current_chunk[0]) + (
separator_len if len(current_chunk) > 1 else 0
)
current_chunk = current_chunk[1:]
current_chunk.append(d)
total += _len + (separator_len if len(current_chunk) > 1 else 0)
chunk = self._join_chunks(current_chunk, separator)
if chunk is not None:
chunks.append(chunk)
return chunks

@staticmethod
def _split_text_with_regex(
text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = (
(
[
_splits[i] + _splits[i + 1]
for i in range(0, len(_splits) - 1, 2)
]
)
if keep_separator == "end"
else (
[_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
)
)
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = (
(splits + [_splits[-1]])
if keep_separator == "end"
else ([_splits[0]] + splits)
)
else:
splits = re.split(separator, text)
else:
splits = list(text)
return [s for s in splits if s != ""]
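To make the intended usage of the base class concrete, here is a minimal sketch of a fixed-separator splitter built on top of BaseSplitter. The CharacterSplitter class and the sample text are illustrative only and not part of this PR; the sketch assumes the helpers behave as defined above.

    import re
    from dataclasses import dataclass
    from typing import List

    from graphgen.bases.base_splitter import BaseSplitter


    @dataclass
    class CharacterSplitter(BaseSplitter):
        """Illustrative splitter: split on a literal separator, then re-merge
        the pieces into chunks of at most chunk_size with chunk_overlap."""

        separator: str = "\n\n"

        def split_text(self, text: str) -> List[str]:
            # Escape the separator so re.split treats it literally.
            splits = self._split_text_with_regex(
                text, re.escape(self.separator), self.keep_separator
            )
            # If the separator was kept on the pieces, join with an empty
            # string so it is not duplicated when pieces are merged back.
            joiner = "" if self.keep_separator else self.separator
            return self._merge_splits(splits, joiner)


    splitter = CharacterSplitter(chunk_size=200, chunk_overlap=20)
    for piece in splitter.split_text("paragraph one\n\nparagraph two\n\nparagraph three"):
        print(piece)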
18 changes: 18 additions & 0 deletions graphgen/bases/datatypes.py
@@ -0,0 +1,18 @@
from dataclasses import dataclass


@dataclass
class Chunk:
id: str
content: str
metadata: dict


@dataclass
class QAPair:
"""
A pair of question and answer.
"""

question: str
answer: str
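Both dataclasses are plain containers; a quick illustration of how they are constructed (all values below are made up):

    from graphgen.bases.datatypes import Chunk, QAPair

    chunk = Chunk(id="doc-0-chunk-0", content="example chunk text", metadata={"start_index": 0})
    qa = QAPair(question="example question", answer="example answer")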
134 changes: 82 additions & 52 deletions graphgen/evaluate.py
@@ -1,27 +1,31 @@
"""Evaluate the quality of the generated text using various metrics"""

import os
import json
import argparse
import json
import os

import pandas as pd
from dotenv import load_dotenv
from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator

from graphgen.bases.datatypes import QAPair

from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
from .utils import logger, set_logger

sys_path = os.path.abspath(os.path.dirname(__file__))
set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))

load_dotenv()


def evaluate_length(corpus, tokenizer_name):
length_evaluator = LengthEvaluator(
tokenizer_name=tokenizer_name
)
length_evaluator = LengthEvaluator(tokenizer_name=tokenizer_name)
logger.info("Length evaluator loaded")
scores = length_evaluator.get_average_score(corpus)
logger.info("Length scores: %s", scores)
return scores


def evaluate_mtld(corpus):
mtld_evaluator = MTLDEvaluator()
logger.info("MTLD evaluator loaded")
@@ -31,30 +35,30 @@ def evaluate_mtld(corpus):
logger.info("MTLD min max scores: %s", min_max_scores)
return scores, min_max_scores


def evaluate_reward(corpus, reward_model_names):
scores = []
for reward_name in reward_model_names:
reward_evaluator = RewardEvaluator(
reward_name=reward_name
)
reward_evaluator = RewardEvaluator(reward_name=reward_name)
logger.info("Loaded reward model: %s", reward_name)
average_score = reward_evaluator.get_average_score(corpus)
logger.info("%s scores: %s", reward_name, average_score)
min_max_scores = reward_evaluator.get_min_max_score(corpus)
logger.info("%s min max scores: %s", reward_name, min_max_scores)
scores.append({
'reward_name': reward_name.split('/')[-1],
'score': average_score,
'min_max_scores': min_max_scores
})
scores.append(
{
"reward_name": reward_name.split("/")[-1],
"score": average_score,
"min_max_scores": min_max_scores,
}
)
del reward_evaluator
clean_gpu_cache()
return scores


def evaluate_uni(corpus, uni_model_name):
uni_evaluator = UniEvaluator(
model_name=uni_model_name
)
uni_evaluator = UniEvaluator(model_name=uni_model_name)
logger.info("Uni evaluator loaded with model %s", uni_model_name)
uni_scores = uni_evaluator.get_average_score(corpus)
for key, value in uni_scores.items():
@@ -64,27 +68,47 @@ def evaluate_uni(corpus, uni_model_name):
logger.info("Uni %s min max scores: %s", key, value)
del uni_evaluator
clean_gpu_cache()
return (uni_scores['naturalness'], uni_scores['coherence'], uni_scores['understandability'],
min_max_scores['naturalness'], min_max_scores['coherence'], min_max_scores['understandability'])
return (
uni_scores["naturalness"],
uni_scores["coherence"],
uni_scores["understandability"],
min_max_scores["naturalness"],
min_max_scores["coherence"],
min_max_scores["understandability"],
)


def clean_gpu_cache():
import torch

if torch.cuda.is_available():
torch.cuda.empty_cache()


if __name__ == '__main__':
if __name__ == "__main__":
import torch.multiprocessing as mp

parser = argparse.ArgumentParser()

parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data')
parser.add_argument('--output', type=str, default='cache/output', help='path to save output')
parser.add_argument(
"--folder", type=str, default="cache/data", help="folder to load data"
)
parser.add_argument(
"--output", type=str, default="cache/output", help="path to save output"
)

parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name')
parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2',
help='Comma-separated list of reward models')
parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name')
parser.add_argument(
"--tokenizer", type=str, default="cl100k_base", help="tokenizer name"
)
parser.add_argument(
"--reward",
type=str,
default="OpenAssistant/reward-model-deberta-v3-large-v2",
help="Comma-separated list of reward models",
)
parser.add_argument(
"--uni", type=str, default="MingZhong/unieval-sum", help="uni model name"
)

args = parser.parse_args()

@@ -94,49 +118,55 @@ def clean_gpu_cache():
if not os.path.exists(args.output):
os.makedirs(args.output)

reward_models = args.reward.split(',')

reward_models = args.reward.split(",")

results = []

logger.info("Data loaded from %s", args.folder)
mp.set_start_method('spawn')
mp.set_start_method("spawn")

for file in os.listdir(args.folder):
if file.endswith('.json'):
if file.endswith(".json"):
logger.info("Processing %s", file)
with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f:
with open(os.path.join(args.folder, file), "r", encoding="utf-8") as f:
data = json.load(f)
data = [TextPair(
question=data[key]['question'],
answer=data[key]['answer']
) for key in data]
data = [
QAPair(question=data[key]["question"], answer=data[key]["answer"])
for key in data
]

length_scores = evaluate_length(data, args.tokenizer)
mtld_scores, min_max_mtld_scores = evaluate_mtld(data)
reward_scores = evaluate_reward(data, reward_models)
uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \
min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \
= evaluate_uni(data, args.uni)
(
uni_naturalness_scores,
uni_coherence_scores,
uni_understandability_scores,
min_max_uni_naturalness_scores,
min_max_uni_coherence_scores,
min_max_uni_understandability_scores,
) = evaluate_uni(data, args.uni)

result = {
'file': file,
'number': len(data),
'length': length_scores,
'mtld': mtld_scores,
'mtld_min_max': min_max_mtld_scores,
'uni_naturalness': uni_naturalness_scores,
'uni_coherence': uni_coherence_scores,
'uni_understandability': uni_understandability_scores,
'uni_naturalness_min_max': min_max_uni_naturalness_scores,
'uni_coherence_min_max': min_max_uni_coherence_scores,
'uni_understandability_min_max': min_max_uni_understandability_scores
"file": file,
"number": len(data),
"length": length_scores,
"mtld": mtld_scores,
"mtld_min_max": min_max_mtld_scores,
"uni_naturalness": uni_naturalness_scores,
"uni_coherence": uni_coherence_scores,
"uni_understandability": uni_understandability_scores,
"uni_naturalness_min_max": min_max_uni_naturalness_scores,
"uni_coherence_min_max": min_max_uni_coherence_scores,
"uni_understandability_min_max": min_max_uni_understandability_scores,
}
for reward_score in reward_scores:
result[reward_score['reward_name']] = reward_score['score']
result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores']
result[reward_score["reward_name"]] = reward_score["score"]
result[f"{reward_score['reward_name']}_min_max"] = reward_score[
"min_max_scores"
]

results.append(result)

results = pd.DataFrame(results)
results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False)
results.to_csv(os.path.join(args.output, "evaluation.csv"), index=False)
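For context, the loading loop above expects every *.json file in --folder to be a mapping from an arbitrary key to an object with question and answer fields. A sketch of a minimal valid input file (file name and contents are illustrative):

    import json

    sample = {
        "qa-0": {"question": "example question", "answer": "example answer"},
        "qa-1": {"question": "another question", "answer": "another answer"},
    }
    with open("cache/data/example.json", "w", encoding="utf-8") as f:
        json.dump(sample, f, ensure_ascii=False, indent=2)

Given the relative imports, the script is presumably run as a module, e.g. python -m graphgen.evaluate --folder cache/data --output cache/output.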
2 changes: 1 addition & 1 deletion graphgen/graphgen.py
@@ -8,8 +8,8 @@
from tqdm.asyncio import tqdm as tqdm_async

from graphgen.bases.base_storage import StorageNameSpace
from graphgen.bases.datatypes import Chunk
from graphgen.models import (
Chunk,
JsonKVStorage,
JsonListStorage,
NetworkXStorage,