Skip to content

Commit c8f4517

Browse files
authored
Merge pull request #85 from boostcampwm-2024/refactor-be-#84
Refactor be #84
2 parents 8b03b64 + 9b2c3e5 commit c8f4517

File tree

6 files changed

+154
-14
lines changed

6 files changed

+154
-14
lines changed

apps/backend/src/langchain/langchain.service.ts

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,28 @@ export class LangchainService {
2222
async query(question: string) {
2323
const promptTemplate = await pull<ChatPromptTemplate>('rlm/rag-prompt');
2424
const queryEmbeddings = await embeddings.embedQuery(question);
25+
// const retrievedDocs = await this.dataSource.query(
26+
// `SELECT content FROM page ORDER BY embedding <=> '[${queryEmbeddings.join(',')}]' LIMIT 1;`,
27+
// );
2528
const retrievedDocs = await this.dataSource.query(
26-
`SELECT content FROM page ORDER BY embedding <=> '[${queryEmbeddings.join(',')}]' LIMIT 1;`,
29+
`
30+
select document from hybrid_search(
31+
'${question}',
32+
'[${queryEmbeddings.join(',')}]'::vector(384),
33+
1
34+
);
35+
`,
2736
);
2837
// const retrievedDocs = await this.vectorStore.similaritySearch(question, 1);
2938
retrievedDocs.forEach((doc) => {
30-
console.log(doc.content);
39+
console.log(doc.document);
3140
});
32-
const docsContent = retrievedDocs
33-
.map((doc) => {
34-
return this.extractTextValues(JSON.parse(JSON.stringify(doc.content)));
35-
})
36-
.join('\n');
41+
// const docsContent = retrievedDocs
42+
// .map((doc) => {
43+
// return this.extractTextValues(JSON.parse(JSON.stringify(doc.content)));
44+
// })
45+
// .join('\n');
46+
const docsContent = retrievedDocs.map((doc) => doc.document).join('\n');
3747

3848
const messages = await promptTemplate.invoke({
3949
question: question,

apps/backend/src/page/page.controller.spec.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ describe('PageController', () => {
6161
node: null,
6262
emoji: null,
6363
workspace: null,
64+
document: null,
65+
fts: null,
6466
});
6567

6668
const result = await controller.createPage(dto);
@@ -126,6 +128,8 @@ describe('PageController', () => {
126128
version: 1,
127129
emoji: null,
128130
workspace: null,
131+
document: null,
132+
fts: null,
129133
};
130134

131135
jest.spyOn(pageService, 'findPageById').mockResolvedValue(expectedPage);

apps/backend/src/page/page.entity.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@ export class Page {
2020
@Column({ nullable: true })
2121
title: string;
2222

23-
@Column('json', { nullable: true }) //TODO: Postgres에서는 jsonb로 변경
23+
@Column('json', { nullable: true })
2424
content: JSON;
2525

26+
@Column({ nullable: true })
27+
document: string;
28+
2629
@CreateDateColumn()
2730
createdAt: Date;
2831

@@ -35,6 +38,14 @@ export class Page {
3538
@Column({ nullable: true })
3639
emoji: string | null;
3740

41+
@Column({
42+
generatedType: 'STORED',
43+
type: 'tsvector',
44+
asExpression: `to_tsvector('english', document)`,
45+
nullable: true,
46+
})
47+
fts: string;
48+
3849
@OneToOne(() => Node, (node) => node.page, {
3950
onDelete: 'CASCADE',
4051
})

apps/backend/src/page/page.repository.ts

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,118 @@ export class PageRepository extends Repository<Page> implements OnModuleInit {
1111
CREATE EXTENSION IF NOT EXISTS vector
1212
`);
1313

14-
// vector 컬럼 추가
14+
// vector 컬럼 추가 (존재 여부 확인 후 추가)
1515
await this.dataSource.query(`
16-
ALTER TABLE page ADD COLUMN embedding vector(384);
17-
`);
16+
DO $$
17+
DECLARE
18+
column_exists BOOLEAN;
19+
index_exists BOOLEAN;
20+
BEGIN
21+
-- 컬럼 존재 여부 확인
22+
SELECT EXISTS (
23+
SELECT 1
24+
FROM pg_attribute
25+
WHERE attrelid = 'page'::regclass
26+
AND attname = 'embedding'
27+
) INTO column_exists;
28+
29+
IF NOT column_exists THEN
30+
-- 컬럼이 없으면 추가
31+
EXECUTE 'ALTER TABLE page ADD COLUMN embedding vector(384)';
32+
END IF;
33+
34+
-- 인덱스 존재 여부 확인
35+
SELECT EXISTS (
36+
SELECT 1
37+
FROM pg_indexes
38+
WHERE tablename = 'page'
39+
AND indexname = 'page_embedding_hnsw_idx'
40+
) INTO index_exists;
41+
42+
IF NOT index_exists THEN
43+
-- HNSW 인덱스 생성
44+
EXECUTE 'CREATE INDEX page_embedding_hnsw_idx ON page USING hnsw (embedding vector_ip_ops)';
45+
END IF;
46+
47+
-- GIN 인덱스 존재 여부 확인
48+
SELECT EXISTS (
49+
SELECT 1
50+
FROM pg_indexes
51+
WHERE tablename = 'page'
52+
AND indexname = 'page_fts_gin_idx'
53+
) INTO index_exists;
54+
55+
IF NOT index_exists THEN
56+
-- GIN 인덱스 생성
57+
EXECUTE 'CREATE INDEX page_fts_gin_idx ON page USING gin(fts)';
58+
END IF;
59+
END $$;
60+
61+
create or replace function hybrid_search(
62+
query_text text,
63+
query_embedding vector(512),
64+
match_count int,
65+
full_text_weight float = 1,
66+
semantic_weight float = 1,
67+
rrf_k int = 50
68+
)
69+
returns setof page
70+
language sql
71+
as $$
72+
with full_text as (
73+
select
74+
id,
75+
-- Note: ts_rank_cd is not indexable but will only rank matches of the where clause
76+
-- which shouldn't be too big
77+
row_number() over(order by ts_rank_cd(fts, websearch_to_tsquery(query_text)) desc) as rank_ix
78+
from
79+
page
80+
where
81+
fts @@ websearch_to_tsquery(query_text)
82+
order by rank_ix
83+
limit least(match_count, 30) * 2
84+
),
85+
semantic as (
86+
select
87+
id,
88+
row_number() over (order by embedding <#> query_embedding) as rank_ix
89+
from
90+
page
91+
order by rank_ix
92+
limit least(match_count, 30) * 2
93+
)
94+
select
95+
page.*
96+
from
97+
full_text
98+
full outer join semantic
99+
on full_text.id = semantic.id
100+
join page
101+
on coalesce(full_text.id, semantic.id) = page.id
102+
order by
103+
coalesce(1.0 / (rrf_k + full_text.rank_ix), 0.0) * full_text_weight +
104+
coalesce(1.0 / (rrf_k + semantic.rank_ix), 0.0) * semantic_weight
105+
desc
106+
limit
107+
least(match_count, 30)
108+
$$;
109+
110+
111+
`);
112+
113+
// fts 컬럼 추가 (존재 여부 확인 후 추가)
114+
await this.dataSource.query(`
115+
DO $$
116+
BEGIN
117+
IF NOT EXISTS (SELECT 1 FROM pg_attribute
118+
WHERE attrelid = 'page'::regclass
119+
AND attname = 'fts')
120+
THEN
121+
ALTER TABLE page
122+
ADD COLUMN fts tsvector GENERATED ALWAYS AS (to_tsvector('english', document)) STORED;
123+
END IF;
124+
END $$;
125+
`);
18126
}
19127

20128
constructor(@InjectDataSource() private dataSource: DataSource) {

apps/backend/src/page/page.service.spec.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ describe('PageService', () => {
9898
node: null,
9999
emoji: null,
100100
workspace: workspace1,
101+
document: null,
102+
fts: null,
101103
};
102104

103105
// 노드 엔티티
@@ -176,6 +178,8 @@ describe('PageService', () => {
176178
version: 1,
177179
emoji: null,
178180
workspace: null,
181+
document: null,
182+
fts: null,
179183
};
180184
// createQueryBuilder를 모킹
181185
const createQueryBuilderMock = jest.fn().mockReturnThis();

apps/backend/src/tasks/tasks.service.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,15 @@ export class TasksService {
9090
}
9191
if (content) {
9292
updateFields.push(`content = $${sequence++}`);
93+
updateFields.push(`document = $${sequence++}`);
9394
updateFields.push(`embedding = $${sequence++}`);
95+
96+
// document는 JSON 타입에서 의미있는 문자열만 뽑아서 합친 문자열
97+
const document = this.extractTextValues(JSON.parse(content));
9498
// content가 있으면 임베딩 진행
95-
const vector = await embeddings.embedDocuments([
96-
this.extractTextValues(JSON.parse(content)),
97-
]);
99+
const vector = await embeddings.embedDocuments([document]);
98100
params.push(content);
101+
params.push(document);
99102
params.push(`[${vector[0].join(',')}]`);
100103
}
101104
if (emoji) {

0 commit comments

Comments
 (0)