diff --git a/CHANGELOG.md b/CHANGELOG.md index a590afe5e190..d9f780d15522 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [Unreleased] + +### Added + +- Added local LLM support for TXT2KG knowledge graph extraction ([#10479](https://github.com/pyg-team/pytorch_geometric/pull/10479)) +- Added uncertainty estimation test for LLM class ([#10479](https://github.com/pyg-team/pytorch_geometric/pull/10479)) + +### Fixed + +- Fixed LLM import path in TXT2KG from `torch_geometric.nn` to `torch_geometric.llm.models` ([#10479](https://github.com/pyg-team/pytorch_geometric/pull/10479)) + ## [2.7.0] - 2025-MM-DD ### Fixed diff --git a/examples/llm/txt2kg_rag.py b/examples/llm/txt2kg_rag.py index 09310ac4802e..ede39a46c6ba 100644 --- a/examples/llm/txt2kg_rag.py +++ b/examples/llm/txt2kg_rag.py @@ -14,6 +14,7 @@ try: import wandb + wandb_available = True except ImportError: wandb_available = False @@ -39,6 +40,10 @@ LLMJudge, SentenceTransformer, ) +# EDFL planner + universal backends (vendor these, +# or pip-install if you packaged them): +# from hallucination_toolkit import OpenAIPlanner, OpenAIItem # imported +# internally by LLM now from torch_geometric.llm.models.txt2kg import _chunk_text from torch_geometric.llm.utils.backend_utils import ( create_graph_from_triples, @@ -70,92 +75,162 @@ def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--gnn_model', type=str, default="GAT", - choices=["GAT", "SGFormer"], - help="The GNN model to use. Default is GAT.") - parser.add_argument('--NV_NIM_MODEL', type=str, - default=NV_NIM_MODEL_DEFAULT, - help="The NIM LLM to use for TXT2KG for LLMJudge") - parser.add_argument('--NV_NIM_KEY', type=str, help="NVIDIA API key") parser.add_argument( - '--ENDPOINT_URL', type=str, default=DEFAULT_ENDPOINT_URL, + "--gnn_model", + type=str, + default="GAT", + choices=["GAT", "SGFormer"], + help="The GNN model to use. 
Default is GAT.", + ) + parser.add_argument( + "--NV_NIM_MODEL", + type=str, + default=NV_NIM_MODEL_DEFAULT, + help="The NIM LLM to use for TXT2KG for LLMJudge", + ) + parser.add_argument("--NV_NIM_KEY", type=str, help="NVIDIA API key") + parser.add_argument( + "--ENDPOINT_URL", + type=str, + default=DEFAULT_ENDPOINT_URL, help="The URL hosting your model, \ - in case you are not using the public NIM.") + in case you are not using the public NIM.", + ) + parser.add_argument( + "--use_local_txt2kg", + action="store_true", + help="Use local LLM for TXT2KG instead of NVIDIA NIM", + ) + parser.add_argument( + "--txt2kg_model", + type=str, + default="mistralai/Mistral-7B-Instruct-v0.3", + help="Local model for TXT2KG", + ) parser.add_argument( - '--kg_chunk_size', type=int, default=KG_CHUNK_SIZE_DEFAULT, + "--kg_chunk_size", + type=int, + default=KG_CHUNK_SIZE_DEFAULT, help="When splitting context documents for txt2kg,\ - the maximum number of characters per chunk.") - parser.add_argument('--gnn_hidden_channels', type=int, - default=GNN_HID_CHANNELS_DEFAULT, - help="Hidden channels for GNN") - parser.add_argument('--num_gnn_layers', type=int, - default=GNN_LAYERS_DEFAULT, - help="Number of GNN layers") - parser.add_argument('--lr', type=float, default=LR_DEFAULT, + the maximum number of characters per chunk.", + ) + parser.add_argument( + "--gnn_hidden_channels", + type=int, + default=GNN_HID_CHANNELS_DEFAULT, + help="Hidden channels for GNN", + ) + parser.add_argument( + "--num_gnn_layers", + type=int, + default=GNN_LAYERS_DEFAULT, + help="Number of GNN layers", + ) + parser.add_argument("--lr", type=float, default=LR_DEFAULT, help="Learning rate") - parser.add_argument('--epochs', type=int, default=EPOCHS_DEFAULT, + parser.add_argument("--epochs", type=int, default=EPOCHS_DEFAULT, help="Number of epochs") - parser.add_argument('--batch_size', type=int, default=BATCH_SIZE_DEFAULT, + parser.add_argument("--batch_size", type=int, default=BATCH_SIZE_DEFAULT, help="Batch size") - parser.add_argument('--eval_batch_size', type=int, - default=EVAL_BATCH_SIZE_DEFAULT, - help="Evaluation batch size") - parser.add_argument('--llm_generator_name', type=str, - default=LLM_GENERATOR_NAME_DEFAULT, - help="The LLM to use for Generation") parser.add_argument( - '--llm_generator_mode', type=str, default=LLM_GEN_MODE_DEFAULT, - choices=["frozen", "lora", - "full"], help="Whether to freeze the Generator LLM,\ - use LORA, or fully finetune") - parser.add_argument('--dont_save_model', action="store_true", - help="Whether to skip model saving.") - parser.add_argument('--log_steps', type=int, default=30, + "--eval_batch_size", + type=int, + default=EVAL_BATCH_SIZE_DEFAULT, + help="Evaluation batch size", + ) + parser.add_argument( + "--llm_generator_name", + type=str, + default=LLM_GENERATOR_NAME_DEFAULT, + help="The LLM to use for Generation", + ) + parser.add_argument( + "--llm_generator_mode", + type=str, + default=LLM_GEN_MODE_DEFAULT, + choices=["frozen", "lora", "full"], + help="Whether to freeze the Generator LLM,\ + use LORA, or fully finetune", + ) + parser.add_argument( + "--dont_save_model", + action="store_true", + help="Whether to skip model saving.", + ) + parser.add_argument("--log_steps", type=int, default=30, help="Log to wandb every N steps") - parser.add_argument('--wandb_project', type=str, default="techqa", - help="Weights & Biases project name") - parser.add_argument('--wandb', action="store_true", + parser.add_argument( + "--wandb_project", + type=str, + default="techqa", + help="Weights & 
Biases project name", + ) + parser.add_argument("--wandb", action="store_true", help="Enable wandb logging") parser.add_argument( - '--num_gpus', type=int, default=None, + "--num_gpus", + type=int, + default=None, help="Number of GPUs to use. If not specified," - "will determine automatically based on model size.") - parser.add_argument('--regenerate_dataset', action="store_true", - help="Regenerate the dataset") + "will determine automatically based on model size.", + ) + parser.add_argument( + "--regenerate_dataset", + action="store_true", + help="Regenerate the dataset", + ) parser.add_argument( - '--doc_parsing_mode', type=str, default=None, - choices=["paragraph", - "file"], help="How to parse documents: 'paragraph' splits " + "--doc_parsing_mode", + type=str, + default=None, + choices=["paragraph", "file"], + help="How to parse documents: 'paragraph' splits " "files by paragraphs, 'file' treats each file as" "one document. " - "This will override any value set in the config file.") + "This will override any value set in the config file.", + ) parser.add_argument( - '--k_for_docs', type=int, default=None, + "--k_for_docs", + type=int, + default=None, help="Number of docs to retrieve for each question. " - "This will override any value set in the config file.") + "This will override any value set in the config file.", + ) parser.add_argument( - '--doc_chunk_size', type=int, default=None, + "--doc_chunk_size", + type=int, + default=None, help="The chunk size to use VectorRAG (document retrieval). " - "This will override any value set in the config file.") + "This will override any value set in the config file.", + ) parser.add_argument( - '--dataset', type=str, default="techqa", help="Dataset folder name, " + "--dataset", + type=str, + default="techqa", + help="Dataset folder name, " "should contain corpus and train.json files." "extracted triples, processed dataset, " "document retriever, and model checkpoints " - "will be saved in the dataset folder") + "will be saved in the dataset folder", + ) parser.add_argument( - '--skip_graph_rag', action="store_true", + "--skip_graph_rag", + action="store_true", help="Skip the graph RAG step. " - "Used to compare the performance of Vector+Graph RAG vs Vector RAG.") + "Used to compare the performance of Vector+Graph RAG vs Vector RAG.", + ) parser.add_argument( - '--use_x_percent_corpus', default=100.0, type=float, + "--use_x_percent_corpus", + default=100.0, + type=float, help="Debug flag that allows user to only use a random percentage " - "of available knowledge base corpus for RAG") + "of available knowledge base corpus for RAG", + ) args = parser.parse_args() assert args.NV_NIM_KEY, "NVIDIA API key is required for TXT2KG and eval" - assert args.use_x_percent_corpus <= 100 and \ - args.use_x_percent_corpus > 0, "Please provide a value in (0,100]" + assert (args.use_x_percent_corpus <= 100 and args.use_x_percent_corpus + > 0), "Please provide a value in (0,100]" if args.skip_graph_rag: print("Skipping graph RAG step, setting GNN layers to 0...") args.num_gnn_layers = 0 @@ -169,7 +244,9 @@ def parse_args(): if config is not None: # Use a loop to check and apply config values for each parameter config_params = [ - 'doc_parsing_mode', 'doc_chunk_size', 'k_for_docs' + "doc_parsing_mode", + "doc_chunk_size", + "k_for_docs", ] for param in config_params: if param in config and getattr(args, param) is None: @@ -214,7 +291,7 @@ def _process_and_chunk_text(text, chunk_size, doc_parsing_mode): if multiple paragraphs are detected. 
""" if doc_parsing_mode == "paragraph": - paragraphs = re.split(r'\n{2,}', text) + paragraphs = re.split(r"\n{2,}", text) else: # doc_parsing_mode == 'file' or doc_parsing_mode is None paragraphs = [text] @@ -258,9 +335,11 @@ def get_data(args): if not os.path.exists(args.dataset): os.mkdir(args.dataset) import zipfile - with zipfile.ZipFile(zip_path, 'r') as zip_ref: + + with zipfile.ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(args.dataset) import shutil + shutil.copy(json_path, os.path.join(args.dataset, "train.json")) elif user_input.lower() == "n" or user_input.lower() == "no": sys.exit("No selected, no data to work with... exiting.") @@ -285,9 +364,11 @@ def get_data(args): raise ValueError(f"Bad extraction for {file_path}, expecting " f"text only but got {doc_type}") text_contexts.extend( - _process_and_chunk_text(data[0]["metadata"]["content"], - args.doc_chunk_size, - args.doc_parsing_mode)) + _process_and_chunk_text( + data[0]["metadata"]["content"], + args.doc_chunk_size, + args.doc_parsing_mode, + )) else: for file_path in glob(os.path.join(args.dataset, "corpus", "*")): with open(file_path, "r+") as f: @@ -304,19 +385,29 @@ def get_data(args): def index_kg(args, context_docs): - kg_maker = TXT2KG(NVIDIA_NIM_MODEL=args.NV_NIM_MODEL, - NVIDIA_API_KEY=args.NV_NIM_KEY, - ENDPOINT_URL=args.ENDPOINT_URL, - chunk_size=args.kg_chunk_size) + if args.use_local_txt2kg: + kg_maker = TXT2KG( + local_LM=True, + local_LM_model_name=args.txt2kg_model, + chunk_size=args.kg_chunk_size, + ) + else: + kg_maker = TXT2KG( + NVIDIA_NIM_MODEL=args.NV_NIM_MODEL, + NVIDIA_API_KEY=args.NV_NIM_KEY, + ENDPOINT_URL=args.ENDPOINT_URL, + chunk_size=args.kg_chunk_size, + ) print( "Note that if the TXT2KG process is too slow for you're liking using " "the public NIM, consider deploying yourself using local_lm flag of " - "TXT2KG or using https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct " # noqa + "TXT2KG or using " + "https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct " "to deploy to a private endpoint, which you can pass to this script " "w/ --ENDPOINT_URL flag.") - print( - "Guide for deploying NIM: https://developer.nvidia.com/blog/a-simple-guide-to-deploying-generative-ai-with-nvidia-nim/" # noqa - ) + print("Guide for deploying NIM: " + "https://developer.nvidia.com/blog/" + "a-simple-guide-to-deploying-generative-ai-with-nvidia-nim/") total_tqdm_count = len(context_docs) initial_tqdm_count = 0 checkpoint_file = list(Path(args.dataset).glob("*--*--checkpoint_kg.pt")) @@ -326,11 +417,16 @@ def index_kg(args, context_docs): if len(checkpoint_file) == 1: print("Restoring KG from checkpoint") checkpoint_file = checkpoint_file[0] - checkpoint_model_name = checkpoint_file.name.split('--')[0] + checkpoint_model_name = checkpoint_file.name.split("--")[0] # check if triples generation are using the correct model - if args.NV_NIM_MODEL.split('/')[-1] != checkpoint_model_name: - raise RuntimeError( - "Error: stored triples were generated using a different model") + if args.use_local_txt2kg: + if checkpoint_model_name != "local": + raise RuntimeError( + "Error: stored triples generated using a different model") + else: + if args.NV_NIM_MODEL.split("/")[-1] != checkpoint_model_name: + raise RuntimeError( + "Error: stored triples generated using a different model") saved_relevant_triples = torch.load(checkpoint_file, weights_only=False) kg_maker.relevant_triples = saved_relevant_triples @@ -340,16 +436,19 @@ def index_kg(args, context_docs): chkpt_interval = 10 chkpt_count = 0 - for 
context_doc in tqdm(context_docs, total=total_tqdm_count, - initial=initial_tqdm_count, - desc="Extracting KG triples"): + for context_doc in tqdm( + context_docs, + total=total_tqdm_count, + initial=initial_tqdm_count, + desc="Extracting KG triples", + ): kg_maker.add_doc_2_KG(txt=context_doc) chkpt_count += 1 if chkpt_count == chkpt_interval: chkpt_count = 0 path = args.dataset + "/{m}--{t}--checkpoint_kg.pt" - model = kg_maker.NIM_MODEL.split( - '/')[-1] if not kg_maker.local_LM else "local" + model = (kg_maker.NIM_MODEL.split("/")[-1] + if not kg_maker.local_LM else "local") path = path.format(m=model, t=datetime.now().strftime("%Y%m%d_%H%M%S")) torch.save(kg_maker.relevant_triples, path) @@ -361,12 +460,13 @@ def index_kg(args, context_docs): triples = list(dict.fromkeys(triples)) raw_triples_path = args.dataset + "/{m}--{t}--raw_triples.pt" - model_name = kg_maker.NIM_MODEL.split( - '/')[-1] if not kg_maker.local_LM else "local" + model_name = (kg_maker.NIM_MODEL.split("/")[-1] + if not kg_maker.local_LM else "local") torch.save( triples, raw_triples_path.format(m=model_name, - t=datetime.now().strftime("%Y%m%d_%H%M%S"))) + t=datetime.now().strftime("%Y%m%d_%H%M%S")), + ) for old_checkpoint_file in Path( args.dataset).glob("*--*--checkpoint_kg.pt"): @@ -379,8 +479,8 @@ def update_data_lists(args, data_lists): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # creating the embedding model sent_trans_batch_size = 256 - model = SentenceTransformer( - model_name=ENCODER_MODEL_NAME_DEFAULT).to(device).eval() + model = (SentenceTransformer( + model_name=ENCODER_MODEL_NAME_DEFAULT).to(device).eval()) model_kwargs = { "output_device": device, "batch_size": int(sent_trans_batch_size / 4), @@ -435,11 +535,17 @@ def make_dataset(args): raise RuntimeError("Error: multiple raw_triples files found") if len(raw_triples_file) == 1: raw_triples_file = raw_triples_file[0] - stored_model_name = raw_triples_file.name.split('--')[0] + stored_model_name = raw_triples_file.name.split("--")[0] - if args.NV_NIM_MODEL.split('/')[-1] != stored_model_name: - raise RuntimeError( - "Error: stored triples were generated using a different model") + # Check if stored triples match current model configuration + if args.use_local_txt2kg: + if stored_model_name != "local": + raise RuntimeError( + "Error: stored triples generated using a different model") + else: + if args.NV_NIM_MODEL.split("/")[-1] != stored_model_name: + raise RuntimeError( + "Error: stored triples generated using a different model") print(f" -> Saved triples generated with: {stored_model_name}") triples = torch.load(raw_triples_file) @@ -457,18 +563,23 @@ def make_dataset(args): print("Creating the graph data from raw triples...") # create the graph data from raw triples graph_data = create_graph_from_triples( - triples=triples, embedding_model=model.encode, + triples=triples, + embedding_model=model.encode, embedding_method_kwargs={ "batch_size": min(len(triples), sent_trans_batch_size), - "verbose": True - }, pre_transform=preprocess_triplet) + "verbose": True, + }, + pre_transform=preprocess_triplet, + ) print("Creating the graph and feature stores...") # creating the graph and feature stores fs, gs = create_remote_backend_from_graph_data( - graph_data=graph_data, path="backend", + graph_data=graph_data, + path="backend", graph_db=NeighborSamplingRAGGraphStore, - feature_db=KNNRAGFeatureStore).load() + feature_db=KNNRAGFeatureStore, + ).load() """ NOTE: these retriever hyperparams are very important. 
Tuning may be needed for custom data... @@ -477,7 +588,7 @@ def make_dataset(args): model_kwargs = { "output_device": device, "batch_size": int(sent_trans_batch_size / 4), - "verbose": True + "verbose": True, } doc_retriever_path = os.path.join(args.dataset, "document_retriever.pt") @@ -490,10 +601,12 @@ def make_dataset(args): vector_retriever.k_for_docs = args.k_for_docs else: print("Creating document retriever...") - vector_retriever = DocumentRetriever(context_docs, - k_for_docs=args.k_for_docs, - model=model.encode, - model_kwargs=model_kwargs) + vector_retriever = DocumentRetriever( + context_docs, + k_for_docs=args.k_for_docs, + model=model.encode, + model_kwargs=model_kwargs, + ) vector_retriever.save(doc_retriever_path) subgraph_filter = make_pcst_filter( @@ -501,8 +614,9 @@ def make_dataset(args): model, topk=5, # nodes topk_e=5, # edges - cost_e=.5, # edge cost - num_clusters=10) # num clusters + cost_e=0.5, # edge cost + num_clusters=10, + ) # num clusters # number of neighbors for each seed node selected by KNN fanout = 100 @@ -521,10 +635,12 @@ def make_dataset(args): print("Now to retrieve context for each query from " "our Vector and Graph DBs...") - query_loader = RAGQueryLoader(graph_data=(fs, gs), - subgraph_filter=subgraph_filter, - vector_retriever=vector_retriever, - config=query_loader_config) + query_loader = RAGQueryLoader( + graph_data=(fs, gs), + subgraph_filter=subgraph_filter, + vector_retriever=vector_retriever, + config=query_loader_config, + ) # pre-process the dataset total_data_list = [] @@ -547,15 +663,17 @@ def make_dataset(args): # stats print("Min # of Retrieved Triples =", min(extracted_triple_sizes)) print("Max # of Retrieved Triples =", max(extracted_triple_sizes)) - print("Average # of Retrieved Triples =", - sum(extracted_triple_sizes) / len(extracted_triple_sizes)) + print( + "Average # of Retrieved Triples =", + sum(extracted_triple_sizes) / len(extracted_triple_sizes), + ) # 60:20:20 split - data_lists["train"] = total_data_list[:int(.6 * len(total_data_list))] - data_lists["validation"] = total_data_list[int(.6 * len(total_data_list) - ):int(.8 * + data_lists["train"] = total_data_list[:int(0.6 * len(total_data_list))] + data_lists["validation"] = total_data_list[int(0.6 * len(total_data_list) + ):int(0.8 * len(total_data_list))] - data_lists["test"] = total_data_list[int(.8 * len(total_data_list)):] + data_lists["test"] = total_data_list[int(0.8 * len(total_data_list)):] dataset_name = os.path.basename(args.dataset) dataset_path = os.path.join(args.dataset, f"{dataset_name}.pt") @@ -568,36 +686,87 @@ def make_dataset(args): def train(args, train_loader, val_loader): if args.wandb: - wandb.init(project=args.wandb_project, - name=f"run_{datetime.now().strftime('%Y-%m-%d_%H:%M')}", - config=vars(args)) + wandb.init( + project=args.wandb_project, + name=f"run_{datetime.now().strftime('%Y-%m-%d_%H:%M')}", + config=vars(args), + ) hidden_channels = args.gnn_hidden_channels num_gnn_layers = args.num_gnn_layers + def _uncertainty_kwargs(): + return dict( + uncertainty_estim=True, + uncertainty_cfg={ + "h_star": 0.05, + "isr_threshold": 1.0, + "m": 6, + "n_samples": 3, + "B_clip": 12.0, + "clip_mode": "one-sided", + "skeleton_policy": "auto", + "q_floor": None, + "temperature": 0.5, + "max_tokens_decision": 8, + "backend": "hf", # or "ollama" / "anthropic" + "mask_refusals_in_loss": True, + }, + decision_backend_kwargs=dict( + # HF backend (default): reuse the same model id; using a + # separate pipeline under the hood. 
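+            # These kwargs are forwarded to whichever decision backend
+            # uncertainty_cfg["backend"] selects (hf / ollama / anthropic).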
+ # If using Ollama or Anthropic instead, pass those + # credentials/args here. + ), + ) + if args.num_gnn_layers > 0: if args.gnn_model == "GAT": - gnn = GAT(in_channels=768, hidden_channels=hidden_channels, - out_channels=1024, num_layers=num_gnn_layers, heads=4) + gnn = GAT( + in_channels=768, + hidden_channels=hidden_channels, + out_channels=1024, + num_layers=num_gnn_layers, + heads=4, + ) elif args.gnn_model == "SGFormer": - gnn = SGFormer(in_channels=768, hidden_channels=hidden_channels, - out_channels=1024, trans_num_heads=1, - trans_dropout=0.5, gnn_num_layers=num_gnn_layers, - gnn_dropout=0.5) + gnn = SGFormer( + in_channels=768, + hidden_channels=hidden_channels, + out_channels=1024, + trans_num_heads=1, + trans_dropout=0.5, + gnn_num_layers=num_gnn_layers, + gnn_dropout=0.5, + ) else: raise ValueError(f"Invalid GNN model: {args.gnn_model}") else: gnn = None if args.llm_generator_mode == "full": - llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt, - n_gpus=args.num_gpus) + llm = LLM( + model_name=args.llm_generator_name, + sys_prompt=sys_prompt, + n_gpus=args.num_gpus, + **_uncertainty_kwargs(), + ) elif args.llm_generator_mode == "lora": - llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt, - dtype=torch.float32, n_gpus=args.num_gpus) + llm = LLM( + model_name=args.llm_generator_name, + sys_prompt=sys_prompt, + dtype=torch.float32, + n_gpus=args.num_gpus, + **_uncertainty_kwargs(), + ) else: # frozen - llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt, - dtype=torch.float32, n_gpus=args.num_gpus).eval() + llm = LLM( + model_name=args.llm_generator_name, + sys_prompt=sys_prompt, + dtype=torch.float32, + n_gpus=args.num_gpus, + **_uncertainty_kwargs(), + ).eval() for _, p in llm.named_parameters(): p.requires_grad = False @@ -617,17 +786,20 @@ def train(args, train_loader, val_loader): else: params = [p for _, p in model.named_parameters() if p.requires_grad] lr = args.lr - optimizer = torch.optim.AdamW([{ - 'params': params, - 'lr': lr, - 'weight_decay': 0.05 - }], betas=(0.9, 0.95)) + optimizer = torch.optim.AdamW( + [{ + "params": params, + "lr": lr, + "weight_decay": 0.05 + }], + betas=(0.9, 0.95), + ) num_oom_errors = 0 for epoch in range(args.epochs): model.train() epoch_loss = 0 - epoch_str = f'Epoch: {epoch + 1}|{args.epochs}' + epoch_str = f"Epoch: {epoch + 1}|{args.epochs}" loader = tqdm(train_loader, desc=epoch_str) for step, batch in enumerate(loader): new_qs = [] @@ -636,7 +808,8 @@ def train(args, train_loader, val_loader): new_qs.append( prompt_template.format( question=q, - context="\n".join(batch.text_context[i]))) + context="\n".join(batch.text_context[i]), + )) batch.question = new_qs if args.skip_graph_rag: @@ -646,37 +819,44 @@ def train(args, train_loader, val_loader): try: loss = get_loss(model, batch) loss.backward() - clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) + clip_grad_norm_(optimizer.param_groups[0]["params"], 0.1) if (step + 1) % 2 == 0: - adjust_learning_rate(optimizer.param_groups[0], lr, - step / len(train_loader) + epoch, - args.epochs) + adjust_learning_rate( + optimizer.param_groups[0], + lr, + step / len(train_loader) + epoch, + args.epochs, + ) optimizer.step() epoch_loss += float(loss.detach()) if args.wandb and (step + 1) % args.log_steps == 0: wandb.log({ "train/loss": float(loss.detach()), - "train/lr": optimizer.param_groups[0]['lr'], + "train/lr": optimizer.param_groups[0]["lr"], }) if (step + 1) % 2 == 0: - lr = optimizer.param_groups[0]['lr'] + lr = 
optimizer.param_groups[0]["lr"] except torch.cuda.OutOfMemoryError: torch.cuda.empty_cache() - print("Sequence length of last batch: ", - model.seq_length_stats[-1]) + print( + "Sequence length of last batch: ", + model.seq_length_stats[-1], + ) # TODO: Implement CPU fallback (WIP) num_oom_errors += 1 print("Sequence length stats: ") - print("seq_len avg: ", - sum(model.seq_length_stats) / len(model.seq_length_stats)) + print( + "seq_len avg: ", + sum(model.seq_length_stats) / len(model.seq_length_stats), + ) print("seq_len min: ", min(model.seq_length_stats)) print("seq_len max: ", max(model.seq_length_stats)) print("Percent of OOM errors: ", num_oom_errors / len(train_loader)) train_loss = epoch_loss / len(train_loader) - print(epoch_str + f', Train Loss: {train_loss:4f}') + print(epoch_str + f", Train Loss: {train_loss:4f}") # Eval Step val_loss = 0 @@ -689,7 +869,8 @@ def train(args, train_loader, val_loader): new_qs.append( prompt_template.format( question=q, - context="\n".join(batch.text_context[i]))) + context="\n".join(batch.text_context[i]), + )) batch.question = new_qs if args.skip_graph_rag: batch.desc = "" @@ -702,7 +883,7 @@ def train(args, train_loader, val_loader): wandb.log({ "val/loss": val_loss, "train/epoch_loss": train_loss, - "epoch": epoch + 1 + "epoch": epoch + 1, }) if args.wandb: @@ -736,8 +917,8 @@ def eval(question: str, pred: str, correct_answer: str): test_batch.question = new_qs if args.skip_graph_rag: test_batch.desc = "" - preds = (inference_step(model, test_batch, - max_out_tokens=max_chars_in_train_answer / 2)) + preds = inference_step(model, test_batch, + max_out_tokens=max_chars_in_train_answer / 2) for question, pred, label in zip(raw_qs, preds, test_batch.label): eval_tuples.append((question, pred, label)) for question, pred, label in tqdm(eval_tuples, desc="Eval"): @@ -749,7 +930,7 @@ def eval(question: str, pred: str, correct_answer: str): print("Improvement of this estimation process is WIP...") -if __name__ == '__main__': +if __name__ == "__main__": # for reproducibility seed_everything(50) @@ -780,13 +961,27 @@ def eval(question: str, pred: str, correct_answer: str): data_lists = make_dataset(args) batch_size = args.batch_size eval_batch_size = args.eval_batch_size - train_loader = DataLoader(data_lists["train"], batch_size=batch_size, - drop_last=True, pin_memory=True, shuffle=True) - val_loader = DataLoader(data_lists["validation"], - batch_size=eval_batch_size, drop_last=False, - pin_memory=True, shuffle=False) - test_loader = DataLoader(data_lists["test"], batch_size=eval_batch_size, - drop_last=False, pin_memory=True, shuffle=False) + train_loader = DataLoader( + data_lists["train"], + batch_size=batch_size, + drop_last=True, + pin_memory=True, + shuffle=True, + ) + val_loader = DataLoader( + data_lists["validation"], + batch_size=eval_batch_size, + drop_last=False, + pin_memory=True, + shuffle=False, + ) + test_loader = DataLoader( + data_lists["test"], + batch_size=eval_batch_size, + drop_last=False, + pin_memory=True, + shuffle=False, + ) model = train(args, train_loader, val_loader) test(model, test_loader, args) diff --git a/test/llm/models/test_llm.py b/test/llm/models/test_llm.py index a6a2608551a8..07283775e73a 100644 --- a/test/llm/models/test_llm.py +++ b/test/llm/models/test_llm.py @@ -10,8 +10,8 @@ @onlyRAG @withPackage('transformers', 'accelerate') def test_llm() -> None: - question = ["Is PyG the best open-source GNN library?"] - answer = ["yes!"] + question = ['Is PyG the best open-source GNN library?'] + answer = ['yes!'] 
model = LLM( model_name='HuggingFaceTB/SmolLM-360M', @@ -30,3 +30,58 @@ def test_llm() -> None: del model gc.collect() torch.cuda.empty_cache() + + +@onlyRAG +@withPackage('transformers', 'accelerate') +def test_llm_uncertainty() -> None: + """Test LLM with uncertainty estimation.""" + question = ['Who won the Nobel Prize in Physics in 2050?'] + + # Test with uncertainty enabled + model = LLM( + model_name='HuggingFaceTB/SmolLM-360M', + num_params=1, + dtype=torch.float16, + uncertainty_estim=True, + uncertainty_cfg={ + 'h_star': 0.05, + 'isr_threshold': 1.0, + 'm': 3, + 'n_samples': 2, + 'B_clip': 12.0, + 'clip_mode': 'one-sided', + 'skeleton_policy': 'auto', + 'temperature': 0.5, + 'max_tokens_decision': 8, + 'backend': 'hf', + 'mask_refusals_in_loss': True, + }, + ) + + assert model.uncertainty_estim is True + + # Test inference with uncertainty + result = model.inference( + question=question, + context=[''], + max_tokens=10, + return_uncertainty=True, + ) + + # Should return tuple of (texts, uncertainties) + assert isinstance(result, tuple) + texts, uncertainties = result + assert len(texts) == 1 + assert len(uncertainties) == 1 + + # Check uncertainty object has expected attributes + uncertainty = uncertainties[0] + assert hasattr(uncertainty, 'isr') + assert hasattr(uncertainty, 'decision_answer') + assert isinstance(uncertainty.isr, float) + assert isinstance(uncertainty.decision_answer, bool) + + del model + gc.collect() + torch.cuda.empty_cache() diff --git a/torch_geometric/llm/models/llm.py b/torch_geometric/llm/models/llm.py index 1a3b2d22eb7b..8596046073ee 100644 --- a/torch_geometric/llm/models/llm.py +++ b/torch_geometric/llm/models/llm.py @@ -1,6 +1,9 @@ +import json +import re import warnings from contextlib import nullcontext -from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import Tensor @@ -13,13 +16,50 @@ IGNORE_INDEX = -100 MAX_TXT_LEN = 512 MAX_NEW_TOKENS = 128 + + +# ---- Uncertainty (EDFL/B2T/ISR) glue ------------------------------------ +@dataclass +class _Uncertainty: + decision_answer: bool + delta_bar: float + b2t: float + isr: float + roh_bound: float + q_avg: float + q_conservative: float + y_label: str + rationale: str + + +def _parse_decision_json(text: str) -> str: + # Compatible with your toolkit's tiny-JSON head. (answer|refuse) + try: + obj = json.loads(text) + d = str(obj.get("decision", "")).strip().lower() + if d in ("answer", "refuse"): + return d + except Exception: + pass + m = re.search(r'{"\s*decision\s*"\s*:\s*"(answer|refuse)"}', text, + flags=re.I) + if m: + return m.group(1).lower() + t = text.strip().lower() + if "refuse" in t and "answer" not in t: + return "refuse" + if "answer" in t and "refuse" not in t: + return "answer" + return "refuse" # default-safe + + PAD_TOKEN_ID = 0 -PADDING_SIDE = 'left' +PADDING_SIDE = "left" # legacy constants - used for Llama 2 style prompting -BOS = '[INST]' -EOS_USER = '[/INST]' -EOS = '[/s]' +BOS = "[INST]" +EOS_USER = "[/INST]" +EOS = "[/s]" def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]: @@ -35,15 +75,15 @@ def get_llm_kwargs(required_memory: int, dtype=torch.dtype) -> Dict[str, Any]: if sum(gpu_memory) < required_memory: gpu_memory = [] # If not enough VRAM, use pure CPU. 
- kwargs = dict(revision='main') + kwargs = dict(revision="main") if len(gpu_memory) > 0: - kwargs['max_memory'] = { - i: f'{memory}GiB' + kwargs["max_memory"] = { + i: f"{memory}GiB" for i, memory in enumerate(gpu_memory) } - kwargs['low_cpu_mem_usage'] = True - kwargs['device_map'] = 'auto' - kwargs['torch_dtype'] = dtype + kwargs["low_cpu_mem_usage"] = True + kwargs["device_map"] = "auto" + kwargs["torch_dtype"] = dtype return kwargs @@ -74,15 +114,21 @@ def __init__( n_gpus: Optional[int] = None, dtype: Optional[torch.dtype] = torch.bfloat16, sys_prompt: Optional[str] = None, + # --- new --- + uncertainty_estim: bool = False, + uncertainty_cfg: Optional[Dict[str, Any]] = None, + decision_backend_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() self.model_name = model_name from transformers import AutoModelForCausalLM, AutoTokenizer + if n_gpus is None: if num_params is None: from huggingface_hub import get_safetensors_metadata + safetensors_metadata = get_safetensors_metadata(model_name) param_count = safetensors_metadata.parameter_count num_params = float(list(param_count.values())[0] // 10**9) @@ -95,14 +141,14 @@ def __init__( gpu_memory: List[int] = [] for i in range(n_gpus): gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3) - kwargs = dict(revision='main') - kwargs['max_memory'] = { - i: f'{memory}GiB' + kwargs = dict(revision="main") + kwargs["max_memory"] = { + i: f"{memory}GiB" for i, memory in enumerate(gpu_memory) } - kwargs['low_cpu_mem_usage'] = True - kwargs['device_map'] = 'auto' - kwargs['torch_dtype'] = dtype + kwargs["low_cpu_mem_usage"] = True + kwargs["device_map"] = "auto" + kwargs["torch_dtype"] = dtype print(f"Setting up '{model_name}' with configuration: {kwargs}") self.tokenizer = AutoTokenizer.from_pretrained( @@ -135,22 +181,158 @@ def __init__( self.sys_prompt = sys_prompt else: self.sys_prompt = "" - if 'max_memory' not in kwargs: # Pure CPU: + if "max_memory" not in kwargs: # Pure CPU: warnings.warn( "LLM is being used on CPU, which may be slow. This decision " "was made by a rough hueristic that assumes your GPU set up " "does not have enough GPU RAM. This is done to avoid GPU OOM " "errors. 
If you think this is a mistake, please initialize " "your LLM with the n_gpus param to dictate how many gpus to " - "use for the LLM.", stacklevel=2) - self.device = torch.device('cpu') + "use for the LLM.", + stacklevel=2, + ) + self.device = torch.device("cpu") self.autocast_context = nullcontext() else: self.device = self.llm.device if dtype == torch.float32: self.autocast_context = nullcontext() else: - self.autocast_context = torch.amp.autocast('cuda', dtype=dtype) + self.autocast_context = torch.amp.autocast("cuda", dtype=dtype) + + # ---- Uncertainty config / backends (opt-in) ---------------------- + self.uncertainty_estim = bool(uncertainty_estim) + self._uncfg: Dict[str, Any] = { + "h_star": 0.05, + "isr_threshold": 1.0, + "m": 6, + "n_samples": 3, + "B_clip": 12.0, + "clip_mode": "one-sided", + "skeleton_policy": "auto", + "q_floor": None, + "temperature": 0.5, + "max_tokens_decision": 8, + "backend": "hf", # "hf" | "ollama" | "anthropic" + "mask_refusals_in_loss": True, + } + if uncertainty_cfg: + self._uncfg.update(uncertainty_cfg) + self._dec_backend_kwargs = decision_backend_kwargs or {} + self._last_uncertainty: Optional[List[_Uncertainty]] = None + + # Try to import your planner/backends; fall back to a local adapter: + self._planner = None + self._backend = None + if self.uncertainty_estim: + try: + from hallbayes import OpenAIItem, OpenAIPlanner + + self._OpenAIPlanner = OpenAIPlanner + self._OpenAIItem = OpenAIItem + try: + from hallbayes import ( + AnthropicBackend, + HuggingFaceBackend, + OllamaBackend, + ) + + self._HFBackend = HuggingFaceBackend + self._OllamaBackend = OllamaBackend + self._AnthropicBackend = AnthropicBackend + except Exception: + self._HFBackend = self._OllamaBackend = ( + self._AnthropicBackend) = None + except Exception: + self.uncertainty_estim = False + warnings.warn( + "Uncertainty requested but `hallucination_toolkit` not " + "importable. 
Proceeding without uncertainty.", + stacklevel=2, + ) + + # ---- helper: turn (question, context) into the string for decision --- + def _decision_user_text(self, q: str, ctx: Optional[str]) -> str: + if ctx and len(ctx) > 0: + return f"{ctx} - {q}" + return q + + # ---- build the decision backend on first use ------------------------- + def _get_decision_backend(self): + if self._backend is not None: + return self._backend + if not self.uncertainty_estim: + return None + backend_kind = str(self._uncfg.get("backend", "hf")).lower() + if (backend_kind == "hf" + and getattr(self, "_HFBackend", None) is not None): + self._backend = self._HFBackend( + mode="transformers", + model_id=self.model_name, + device_map="auto", + trust_remote_code=True, + **self._dec_backend_kwargs, + ) + elif (backend_kind == "ollama" + and getattr(self, "_OllamaBackend", None) is not None): + self._backend = self._OllamaBackend(**self._dec_backend_kwargs) + elif (backend_kind == "anthropic" + and getattr(self, "_AnthropicBackend", None) is not None): + self._backend = self._AnthropicBackend(**self._dec_backend_kwargs) + else: + self._backend = None + return self._backend + + # ---- compute EDFL/B2T/ISR metrics per item --------------------------- + def _compute_uncertainty( + self, + questions: List[str], + context: Optional[List[str]] = None, + ) -> List[_Uncertainty]: + backend = self._get_decision_backend() + Planner = getattr(self, "_OpenAIPlanner", None) + Item = getattr(self, "_OpenAIItem", None) + if backend is None or Planner is None or Item is None: + return [] + planner = Planner( + backend=backend, + temperature=float(self._uncfg["temperature"]), + max_tokens_decision=int(self._uncfg["max_tokens_decision"]), + q_floor=self._uncfg["q_floor"], + ) + items = [] + for i, q in enumerate(questions): + ctx_i = None if context is None else context[i] + user_txt = self._decision_user_text(q, ctx_i) + items.append( + Item( + prompt=user_txt, + n_samples=int(self._uncfg["n_samples"]), + m=int(self._uncfg["m"]), + skeleton_policy=str(self._uncfg["skeleton_policy"]), + )) + metrics = planner.run( + items, + h_star=float(self._uncfg["h_star"]), + isr_threshold=float(self._uncfg["isr_threshold"]), + B_clip=float(self._uncfg["B_clip"]), + clip_mode=str(self._uncfg["clip_mode"]), + ) + outs: List[_Uncertainty] = [] + for m in metrics: + outs.append( + _Uncertainty( + decision_answer=bool(m.decision_answer), + delta_bar=float(m.delta_bar), + b2t=float(m.b2t), + isr=float(m.isr), + roh_bound=float(m.roh_bound), + q_avg=float(m.q_avg), + q_conservative=float(m.q_conservative), + y_label=str(m.meta.get("y_label", "") if m.meta else ""), + rationale=str(m.rationale), + )) + return outs # legacy function - used for Llama 2 style prompting def _encode_inputs( @@ -164,17 +346,23 @@ def _encode_inputs( context = self.tokenizer(context, add_special_tokens=False) eos_user_tokens = self.tokenizer(EOS_USER, add_special_tokens=False) - bos_token = self.tokenizer( + bos_token = (self.tokenizer( BOS, add_special_tokens=False, - return_tensors='pt', - ).input_ids[0].to(self.device) + return_tensors="pt", + ).input_ids[0].to(self.device)) bos_embeds = self.word_embedding(bos_token) pad_token = torch.tensor(self.tokenizer.pad_token_id, device=self.device) pad_embeds = self.word_embedding(pad_token).unsqueeze(0) - return (batch_size, questions, context, eos_user_tokens, bos_embeds, - pad_embeds) + return ( + batch_size, + questions, + context, + eos_user_tokens, + bos_embeds, + pad_embeds, + ) def _label_input_ids( self, @@ -269,8 +457,14 @@ 
def _get_embeds_old( embedding: Optional[List[Tensor]] = None, answer: Optional[List[str]] = None, ) -> tuple: - (batch_size, question, context, eos_user_tokens, bos_embeds, - pad_embeds) = self._encode_inputs(question, context) + ( + batch_size, + question, + context, + eos_user_tokens, + bos_embeds, + pad_embeds, + ) = self._encode_inputs(question, context) batch_label_input_ids = None if answer is not None: @@ -304,8 +498,11 @@ def _get_embeds_old( ) inputs_embeds, attention_mask, label_input_ids = self._pad_embeds( - pad_embeds, batch_inputs_embeds, batch_attention_mask, - batch_label_input_ids) + pad_embeds, + batch_inputs_embeds, + batch_attention_mask, + batch_label_input_ids, + ) return inputs_embeds, attention_mask, label_input_ids @@ -321,7 +518,9 @@ def _get_embeds( f"HuggingFace model {self.model_name} is not using a " "chat template, using Llama 2 style prompting. Please " "consider using a more recent model and initialize the " - "LLM with `sys_prompt`.", stacklevel=2) + "LLM with `sys_prompt`.", + stacklevel=2, + ) return self._get_embeds_old(question, context, embedding, answer) batch_label_input_ids = None if answer is not None: @@ -359,11 +558,11 @@ def _get_embeds( else: label_input_ids = None - bos_token = self.tokenizer( + bos_token = (self.tokenizer( self.tokenizer.bos_token, add_special_tokens=False, - return_tensors='pt', - ).input_ids[0].to(self.device) + return_tensors="pt", + ).input_ids[0].to(self.device)) bos_embeds = self.word_embedding(bos_token) @@ -393,8 +592,11 @@ def _get_embeds( pad_embeds = self.word_embedding(pad_token).unsqueeze(0) inputs_embeds, attention_mask, label_input_ids = self._pad_embeds( - pad_embeds, batch_inputs_embeds, batch_attention_mask, - batch_label_input_ids) + pad_embeds, + batch_inputs_embeds, + batch_attention_mask, + batch_label_input_ids, + ) return inputs_embeds, attention_mask, label_input_ids @@ -420,6 +622,16 @@ def forward( inputs_embeds, attention_mask, label_input_ids = self._get_embeds( question, context, embedding, answer) + # --- Uncertainty-aware training: mask labels for ISR<1 ------------- + if self.uncertainty_estim and label_input_ids is not None: + unc = self._compute_uncertainty(question, context) + self._last_uncertainty = unc + if self._uncfg.get("mask_refusals_in_loss", + True) and len(unc) == len(question): + for i, u in enumerate(unc): + if not u.decision_answer: + label_input_ids[i, :] = IGNORE_INDEX + with self.autocast_context: outputs = self.llm( inputs_embeds=inputs_embeds, @@ -436,12 +648,14 @@ def inference( context: Optional[List[str]] = None, embedding: Optional[List[Tensor]] = None, max_tokens: Optional[int] = MAX_NEW_TOKENS, - ) -> List[str]: + # --- new --- + return_uncertainty: bool = False, + abstain_on_low_ISR: bool = False, + ) -> Union[List[str], Tuple[List[str], List[_Uncertainty]]]: r"""The inference pass. Args: question (list[str]): The questions/prompts. - answer (list[str]): The answers/labels. context (list[str], optional): Additional context to give to the LLM, such as textified knowledge graphs. (default: :obj:`None`) embedding (list[torch.Tensor], optional): RAG embedding @@ -450,10 +664,21 @@ def inference( both. (default: :obj:`None`) max_tokens (int, optional): How many tokens for the LLM to generate. (default: :obj:`32`) + return_uncertainty (bool, optional): Whether to also return the + uncertainty estimates computed for each item. 
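+                When uncertainty estimation is active, a
+                ``(texts, uncertainties)`` tuple is returned instead of a
+                plain list of strings.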
(default: + :obj:`False`) + abstain_on_low_ISR (bool, optional): Whether to replace outputs + with ``"[ABSTAIN]"`` when the ISR-driven decision flags the + answer as a refusal. (default: :obj:`False`) """ inputs_embeds, attention_mask, _ = self._get_embeds( question, context, embedding) + unc: List[_Uncertainty] = [] + if self.uncertainty_estim: + unc = self._compute_uncertainty(question, context) + self._last_uncertainty = unc + with self.autocast_context: outputs = self.llm.generate( inputs_embeds=inputs_embeds, @@ -464,7 +689,15 @@ def inference( use_cache=True, ) - return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + texts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + if self.uncertainty_estim and unc and abstain_on_low_ISR: + texts = [ + t if u.decision_answer else "[ABSTAIN]" + for t, u in zip(texts, unc) + ] + if return_uncertainty and (self.uncertainty_estim and unc): + return texts, unc + return texts def __repr__(self) -> str: - return f'{self.__class__.__name__}({self.model_name})' + return f"{self.__class__.__name__}({self.model_name})" diff --git a/torch_geometric/llm/models/txt2kg.py b/torch_geometric/llm/models/txt2kg.py index ac011b15d248..f94143157b35 100644 --- a/torch_geometric/llm/models/txt2kg.py +++ b/torch_geometric/llm/models/txt2kg.py @@ -43,20 +43,23 @@ class TXT2KG(): than deploying your own private NIM endpoint. This flag is mainly recommended for dev/debug. (default: False). + local_LM_model_name : str, optional + Model name to load when `local_LM` is enabled. + (default: "VAGOsolutions/SauerkrautLM-v2-14b-DPO"). chunk_size : int, optional The size of the chunks in which the text data is processed (default: 512). """ - def __init__( - self, - NVIDIA_NIM_MODEL: Optional[ - str] = "nvidia/llama-3.1-nemotron-70b-instruct", - NVIDIA_API_KEY: Optional[str] = "", - ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1", - local_LM: bool = False, - chunk_size: int = 512, - ) -> None: + def __init__(self, NVIDIA_NIM_MODEL: Optional[ + str] = "nvidia/llama-3.1-nemotron-70b-instruct", + NVIDIA_API_KEY: Optional[str] = "", ENDPOINT_URL: Optional[ + str] = "https://integrate.api.nvidia.com/v1", + local_LM: bool = False, local_LM_model_name: Optional[ + str] = "VAGOsolutions/SauerkrautLM-v2-14b-DPO", + chunk_size: int = 512, n_gpus=None) -> None: + self.n_gpus = n_gpus self.local_LM = local_LM + self.local_LM_model_name = local_LM_model_name # Initialize the local LM flag and the NIM model info accordingly if self.local_LM: # If using a local LM, set the initd_LM flag to False @@ -91,9 +94,12 @@ def _chunk_to_triples_str_local(self, txt: str) -> str: # call LLM on text chunk_start_time = time.time() if not self.initd_LM: - from torch_geometric.nn.nlp import LLM - LM_name = "VAGOsolutions/SauerkrautLM-v2-14b-DPO" - self.model = LLM(LM_name).eval() + from torch_geometric.llm.models import LLM + LM_name = self.local_LM_model_name + if self.n_gpus is not None: + self.model = LLM(LM_name, n_gpus=self.n_gpus).eval() + else: + self.model = LLM(LM_name).eval() self.initd_LM = True out_str = self.model.inference(question=[txt + '\n' + SYSTEM_PROMPT], max_tokens=self.chunk_size)[0]
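
Usage sketch (illustrative, not part of the patch): the pieces added above can be combined roughly as follows. The model names are the defaults used in this PR (Mistral-7B-Instruct-v0.3 for local TXT2KG, SmolLM-360M from the new test), and the uncertainty path only stays enabled if the `hallbayes` toolkit is importable; otherwise `LLM.inference` falls back to returning plain text.

import torch

from torch_geometric.llm.models import LLM
from torch_geometric.llm.models.txt2kg import TXT2KG

# Local (NIM-free) triple extraction, mirroring --use_local_txt2kg / --txt2kg_model.
kg_maker = TXT2KG(
    local_LM=True,
    local_LM_model_name="mistralai/Mistral-7B-Instruct-v0.3",
    chunk_size=512,
)
kg_maker.add_doc_2_KG(txt="PyG provides GNN layers for PyTorch.")

# Generator LLM with EDFL/ISR uncertainty estimation, as exercised by the new test.
llm = LLM(
    model_name="HuggingFaceTB/SmolLM-360M",
    num_params=1,
    dtype=torch.float16,
    uncertainty_estim=True,
    uncertainty_cfg={"h_star": 0.05, "isr_threshold": 1.0, "backend": "hf"},
)
result = llm.inference(
    question=["Who won the Nobel Prize in Physics in 2050?"],
    context=[""],
    max_tokens=10,
    return_uncertainty=True,
)
if isinstance(result, tuple):  # hallbayes importable: (texts, uncertainties)
    texts, uncertainties = result
    print(texts[0], uncertainties[0].isr, uncertainties[0].decision_answer)
else:  # toolkit missing: uncertainty silently disabled, plain strings returned
    print(result[0])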