Skip to content

Commit 2a120a6

Browse files
drivanov, pre-commit-ci[bot], akihironitta
authored
Improvement of the txt2kg example. (#10623)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta <akihiro@kumo.ai>
1 parent fb44ee4 commit 2a120a6

File tree

8 files changed

+677
-67
lines changed

8 files changed

+677
-67
lines changed

test/llm/models/test_g_retriever.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import gc
2+
from contextlib import nullcontext
3+
from types import SimpleNamespace
24

5+
import pytest
36
import torch
7+
from torch import nn
48

59
from torch_geometric.llm.models import LLM, GRetriever
610
from torch_geometric.nn import GAT
@@ -100,3 +104,88 @@ def test_g_retriever_many_tokens() -> None:
100104
del model, llm, gnn
101105
gc.collect()
102106
torch.cuda.empty_cache()
107+
108+
109+
class DummyHFModel(nn.Module):
    """Stand-in for a HuggingFace causal LM.

    For any ``inputs_embeds`` batch it emits random logits over a small
    vocabulary and a constant zero loss.
    """
    def __init__(self, vocab_size=10):
        super().__init__()
        self.vocab_size = vocab_size
        # Single parameter so the module has something for optimizers.
        self.dummy = nn.Parameter(torch.zeros(1))

    def forward(self, inputs_embeds=None, **kwargs):
        batch, seq_len, _ = inputs_embeds.shape
        logits = torch.randn(batch, seq_len, self.vocab_size,
                             device=inputs_embeds.device)
        loss = torch.tensor(0.0, device=inputs_embeds.device)
        # NOTE(review): logits are also attached as an attribute of the
        # loss tensor itself — presumably some caller reads ``loss.logits``;
        # confirm before removing.
        loss.logits = logits
        return SimpleNamespace(logits=logits, loss=loss)
125+
126+
127+
class DummyLLM:
    """Minimal stub mimicking the interface GRetriever expects of ``LLM``."""
    def __init__(self, hidden_dim):
        self.word_embedding = nn.Embedding(100, hidden_dim)
        self.llm = DummyHFModel()
        self.device = torch.device("cpu")
        self.autocast_context = nullcontext()

    def _get_embeds(self, question, *args):
        """Return fixed-length random embeddings for each question.

        Extra positional arguments (context, graph embeddings, ...) are
        accepted but ignored by this stub.
        """
        num_questions = len(question)
        fixed_len = 4
        dim = self.word_embedding.embedding_dim

        embeds = torch.randn(num_questions, fixed_len, dim)
        mask = torch.ones(num_questions, fixed_len, dtype=torch.long)
        return embeds, mask, None
143+
144+
145+
class DummyGNN(nn.Module):
    """Simple GNN stub returning node embeddings via a single linear map."""
    def __init__(self, in_channels=4, out_channels=8):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, *args, **kwargs):
        # Only the node features (first positional arg) matter; the rest
        # (edge_index, batch, ...) are accepted for API compatibility.
        node_feats = args[0]
        return self.lin(node_feats)
156+
157+
158+
@pytest.mark.parametrize("batch_size", [1, 3])
def test_gretriever_prefix_embedding_injection(batch_size):
    """GRetriever wired with dummy LLM/GNN must emit per-sample logits."""
    hidden_dim = 8
    num_nodes = 5

    model = GRetriever(
        llm=DummyLLM(hidden_dim),
        gnn=DummyGNN(in_channels=4, out_channels=8),
        mlp_out_tokens=2,
    )

    # Tiny single-graph input.
    node_feats = torch.randn(num_nodes, 4)
    edges = torch.tensor([[0, 1, 2], [1, 2, 3]])
    node_to_graph = torch.zeros(num_nodes, dtype=torch.long)

    prompts = ["What is this graph?"] * batch_size
    answers = ["dummy answer"] * batch_size

    out = model(
        x=node_feats,
        edge_index=edges,
        batch=node_to_graph,
        question=prompts,
        label=answers,
    )

    # Basic correctness: output carries logits with one row per sample.
    assert hasattr(out, "logits")
    assert out.logits.shape[0] == batch_size

test/llm/models/test_llm.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import gc
22

3+
import pytest
34
import torch
45
from torch import Tensor
56

@@ -28,3 +29,96 @@ def test_llm() -> None:
2829
del model
2930
gc.collect()
3031
torch.cuda.empty_cache()
32+
33+
34+
class DummyBatch(dict):
    """Dict subclass mimicking HuggingFace ``BatchEncoding``.

    ``to`` is a no-op that returns the same object, since this stub only
    ever lives on CPU.
    """
    def to(self, device):
        return self
38+
39+
40+
class DummyTokenizer:
    """Left-padding tokenizer stub.

    Each input string maps to token ids ``1..len(s)``; every sequence is
    left-padded with ``pad_token_id`` (0) up to the longest string in the
    batch, mirroring HuggingFace's ``padding_side = 'left'`` behavior.
    """
    pad_token_id = 0
    padding_side = "left"

    def __call__(self, texts, return_tensors=None, padding=True):
        # ``return_tensors`` and ``padding`` are accepted for API
        # compatibility only; this stub always returns padded tensors.
        lengths = [len(t) for t in texts]
        max_len = max(lengths)

        ids = []
        mask = []

        for seq_len in lengths:
            # Fix: renamed from ``padding``, which shadowed the keyword
            # argument of the same name above.
            pad_len = max_len - seq_len
            ids.append([0] * pad_len + list(range(1, seq_len + 1)))
            mask.append([0] * pad_len + [1] * seq_len)

        return DummyBatch({
            "input_ids": torch.tensor(ids),
            "attention_mask": torch.tensor(mask),
        })
60+
61+
62+
class DummyModel(torch.nn.Module):
    """HF-model stub: embedding table on demand, all-zero logits out."""
    def get_input_embeddings(self):
        # NOTE: a new (randomly initialized) table is built on every call.
        return torch.nn.Embedding(100, 8)

    def forward(self, inputs_embeds=None, attention_mask=None, **kwargs):
        num_seqs, seq_len, _ = inputs_embeds.shape

        class _Output:
            pass

        result = _Output()
        result.logits = torch.zeros(num_seqs, seq_len, 10)
        return result
75+
76+
77+
@pytest.fixture
def dummy_llm():
    """Build an ``LLM`` instance without running its (heavy) ``__init__``."""
    obj = LLM.__new__(LLM)  # skip model download / tokenizer setup
    torch.nn.Module.__init__(obj)
    obj.device = torch.device("cpu")
    obj.tokenizer = DummyTokenizer()
    obj.model = DummyModel()
    return obj
85+
86+
87+
@onlyRAG
def test_llm_prepare_inputs(dummy_llm):
    """Tokenize, embed and forward a two-prompt batch end to end."""
    encoded = dummy_llm.tokenizer(["hello", "hi"])
    ids = encoded["input_ids"]
    mask = encoded["attention_mask"]

    embeds = dummy_llm.model.get_input_embeddings()(ids)
    result = dummy_llm.model(inputs_embeds=embeds, attention_mask=mask)

    assert embeds.shape[0] == 2
    assert mask.shape == ids.shape
    assert hasattr(result, "logits")
    assert result.logits.shape[:2] == embeds.shape[:2]
106+
107+
108+
@onlyRAG
def test_llm_single_prompt(dummy_llm):
    """A single prompt yields a batch of size one."""
    batch = dummy_llm.tokenizer(["test"])
    assert batch["input_ids"].shape[0] == 1
113+
114+
115+
@onlyRAG
def test_llm_variable_lengths(dummy_llm):
    """Sequences are padded to the length of the longest prompt."""
    prompts = ["a", "abcdef", "abc"]

    ids = dummy_llm.tokenizer(prompts)["input_ids"]

    assert ids.shape[0] == 3
    assert ids.shape[1] == max(len(p) for p in prompts)

test/llm/models/test_sentence_transformer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,30 @@
1010
@pytest.mark.parametrize('batch_size', [None, 1])
@pytest.mark.parametrize('pooling_strategy', ['mean', 'last', 'cls'])
def test_sentence_transformer(batch_size, pooling_strategy, device):
    """Encode a small batch and check device placement and output shape."""
    model_name = 'bert-base-uncased'
    model = SentenceTransformer(
        model_name=model_name,
        pooling_strategy=pooling_strategy,
    ).to(device)
    assert model.device == device
    assert str(model) == f'SentenceTransformer(model_name={model_name})'

    sentences = [
        "this is a basic english text",
        "PyG is the best open-source GNN library :)",
    ]

    # Expected embedding width comes from the underlying HF config.
    embed_dim = model.model.config.hidden_size

    out = model.encode(sentences, batch_size=batch_size)
    assert out.device == device
    assert out.shape == (2, embed_dim)

    out = model.encode(sentences, batch_size=batch_size, output_device='cpu')
    assert out.is_cpu
    assert out.shape == (2, embed_dim)

    # Empty input must still yield a correctly-shaped empty tensor.
    out = model.encode([], batch_size=batch_size)
    assert out.device == device
    assert out.shape == (0, embed_dim)

0 commit comments

Comments
 (0)