diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100755
index 0000000..73fa5dc
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,11 @@
+blank_issues_enabled: false
+issue_templates:
+ - name: Feature Template
+ description: Suggest a feature for this project ๐ฉโ๐ป
+ file: feature.md
+ - name: Experiment Template
+ description: Suggest an experiment for this project ๐ง๐ปโ๐ฌ
+ file: experiment.md
+ - name: Research Template
+ description: Suggest research to generate ideas ๐จโ๐ซ
+ file: research.md
diff --git a/.github/ISSUE_TEMPLATE/experiment.md b/.github/ISSUE_TEMPLATE/experiment.md
new file mode 100755
index 0000000..2ad250e
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/experiment.md
@@ -0,0 +1,38 @@
+---
+name: ๐ Experiment Request
+about: Suggest an experiment for this project ๐ง๐ปโ๐ฌ
+title: "[EXP]"
+labels: experiment
+assignees:
+---
+# ๐ Experiment
+
+## ๐ฅ ์คํ ๊ทผ๊ฑฐ
+
+- ๋ ํผ๋ฐ์ค (๋
ผ๋ฌธ, ๊ฐ์, ํฌ์คํ
)
+- ํฉ๋ฆฌ์ ์ถ๋ก
+
+## ๐ ๋ด์ฉ
+
+- ์คํ์ ๋ํ ์์ธํ ๋ด์ฉ
+- ์คํ ํ๊ฒฝ๊ณผ ๋ณ์ธ ํต์ ๋ฐ๋์ ๊ธฐ์
+
+## ๐ฃ ์์ ๊ฒฐ๊ณผ
+
+- ๋ฐ๋์ ์ด์ ์ ํจ๊ป ์์ ๊ฒฐ๊ณผ ์์ฑ
+
+## ๐ณ ์ค์ ๊ฒฐ๊ณผ
+
+- ์์ ๊ฒฐ๊ณผ์ ๋ฌ๋๋ค๋ฉด ๊ทธ ์ด์ ๋ ํจ๊ป ์์ฑ
+
+## ๐ ์คํ ์ ๋ณด
+
+- wandb ๋งํฌ
+- ์ ์ถ ๊ฒฐ๊ณผ ๋
ธ์
๋งํฌ
+
+## ๐ ์ฒดํฌ๋ฆฌ์คํธ
+
+- [ ] todo 1
+- [ ] todo 2
+
+---
diff --git a/.github/ISSUE_TEMPLATE/feature.md b/.github/ISSUE_TEMPLATE/feature.md
new file mode 100755
index 0000000..eaaddba
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature.md
@@ -0,0 +1,20 @@
+---
+name: ๐ Feature Request
+about: Suggest a feature for this project ๐ฉโ๐ป
+title: "[FEAT]"
+labels: enhancement
+assignees:
+---
+# ๐ Feature
+
+## ๐ ๋ด์ฉ
+
+- context 1
+- context 2
+
+## ๐ ์ฒดํฌ๋ฆฌ์คํธ
+
+- [ ] todo 1
+- [ ] todo 2
+
+---
diff --git a/.github/ISSUE_TEMPLATE/research.md b/.github/ISSUE_TEMPLATE/research.md
new file mode 100755
index 0000000..4395cc6
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/research.md
@@ -0,0 +1,25 @@
+---
+name: ๐ Research Request
+about: Suggest research to generate ideas ๐จโ๐ซ
+title: "[RES]"
+labels: research
+assignees:
+---
+# ๐ Research
+
+## ๐ ๋ด์ฉ
+
+- context 1
+- context 2
+
+## ๐ง ๊ฒฐ๋ก ๋ฐ ์คํ ๊ฐ๋ฅ์ฑ
+
+- conclusion 1
+- conclusion 2
+
+## ๐ ์ฒดํฌ๋ฆฌ์คํธ
+
+- [ ] todo 1
+- [ ] todo 2
+
+---
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 0000000..c8e6346
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,11 @@
+## Description
+
+- ์ด๋ฒ PR์์ ์์
ํ ๋ด์ฉ์ ๊ฐ๋ตํ ์ค๋ช
+
+## Refer to the reviewer
+
+- ๋ฆฌ๋ทฐ์ด์๊ฒ ํ์ํ ์ค๋ช
์ด๋ ํน๋ณํ ๋ด์ฃผ์์ผ๋ฉด ํ๋ ๋ถ๋ถ์ ์์ฑ
+
+## Related Issue
+
+- #์ด์๋ฒํธ
diff --git a/.github/workflows/check-lint.yml b/.github/workflows/check-lint.yml
new file mode 100755
index 0000000..c121e18
--- /dev/null
+++ b/.github/workflows/check-lint.yml
@@ -0,0 +1,24 @@
+name: check-lint
+
+on: [pull_request]
+
+jobs:
+ check-lint:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+
+ - name: Set up Python 3.11
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.11"
+
+ - name: Install dependencies
+ run: |
+ python3 -m pip install --upgrade pip
+
+ - name: Check Lint
+ run: |
+ make quality
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..4c78c18
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,177 @@
+# Custom
+.idea/
+**/data/
+**/output/
+**/outputs/
+**/wandb/
+**/*.out
+config/*.yaml
+config/token.json
+config/credentials.json
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Mac
+**/.DS_Store
+.vscode/settings.json
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..660d22a
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,25 @@
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ - id: check-yaml
+ - id: check-added-large-files
+ - id: check-merge-conflict
+
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.7.2
+ hooks:
+ - id: ruff
+ args: [--fix]
+ - id: ruff-format
+
+ - repo: local
+ hooks:
+ - id: pytest
+ name: pytest
+ entry: python3 -m pytest
+ language: system
+ pass_filenames: false
+ types: [python]
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..1dc5558
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,48 @@
+clean: clean-pyc clean-test
+quality: set-style-dep check-quality
+style: set-style-dep set-style
+setup: set-precommit set-style-dep set-test-dep set-git set-dev
+test: set-test-dep set-test
+
+
+##### basic #####
+set-git:
+ git config --local commit.template .gitmessage
+
+set-style-dep:
+ pip3 install ruff==0.7.2
+
+set-test-dep:
+ pip3 install pytest==8.3.2
+
+set-precommit:
+ pip3 install pre-commit==4.0.1
+ pre-commit install
+
+set-dev:
+ pip3 install -r ./requirements.txt
+
+set-test:
+ python3 -m pytest tests/
+
+set-style:
+ ruff check --fix .
+ ruff format .
+
+check-quality:
+ ruff check .
+ ruff format --check .
+
+##### clean #####
+clean-pyc:
+ find . -name '*.pyc' -exec rm -f {} +
+ find . -name '*.pyo' -exec rm -f {} +
+ find . -name '*~' -exec rm -f {} +
+ find . -name '__pycache__' -exec rm -fr {} +
+
+clean-test:
+ rm -f .coverage
+ rm -f .coverage.*
+ rm -rf .pytest_cache
+ rm -rf .mypy_cache
+ rm -rf .ruff_cache
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..b7aaf50
--- /dev/null
+++ b/README.md
@@ -0,0 +1,196 @@
+
+
+# ๐ Lv.2 NLP Project : ์๋ฅ ๋ฌธ์ ํ์ด AI ๋ชจ๋ธ ์์ฑ
+
+
+## โ๏ธ ๋ํ ์๊ฐ
+| ํน์ง | ์ค๋ช
|
+|:------:|--------------------------------------------------------------------------------------------------------|
+| ๋ํ ์ฃผ์ | ๋ค์ด๋ฒ ๋ถ์คํธ์บ ํ AI-Tech 7๊ธฐ NLP ํธ๋์ level 2 Generation for NLP ๋ํ
|
+| ๋ํ ์ค๋ช
| ํ๊ตญ์ด์ ํน์ฑ๊ณผ ์๋ฅ ์ํ์ ํน์ง์ ๋ฐํ์ผ๋ก ์๋ฅ์ ํนํ๋ AI ๋ชจ๋ธ์ ์์ฑํ๋ ํ๋ก์ ํธ |
+| ์งํ ๊ธฐ๊ฐ |2024๋
11์ 11์ ~ 2024๋
11์ 28์ผ|
+| ๋ฐ์ดํฐ ๊ตฌ์ฑ | ํ์ต๋ฐ์ดํฐ ์
: KMMLU / MMMLU(Ko) / KLUE MRC ์ค 2031๊ฐํ๊ฐ๋ฐ์ดํฐ ์
: ์๋ฅํ ๋ฌธ์ + KMMLU / MMMLU(Ko) / KLUE MRC ์ด 869๊ฐ |
+| ํ๊ฐ ์งํ | ์ ํ๋(Accuracy) = ๋ชจ๋ธ์ด ๋ง์ถ ๋ฌธ์ ์ / ์ ์ฒด ๋ฌธ์ ์ |
+
+## ๐๏ธ Leader Board
+### ๐ฅ Public Leader Board (2์)
+
+### ๐ฅ Private Leader Board (2์)
+
+
+## ๐จโ๐ป Contributors
+
+
+## ๐ผ ์ญํ ๋ถ๋ด
+| ์ด๋ฆ | ์ญํ |
+| --- |---------------------------------------------------------------------------------------------|
+| ๊น๋ฏผ์ | ์ต์ ํ ์๋ฃจ์
(DeepSpeed), ์์ํ(Optimizer Quantization), ๋์ด๋ ๊ธฐ๋ฐ ๋ฐ์ดํฐ ์ฆ๊ฐ |
+| ๊น์์ง | EDA(๊ตญ์ด์์ญ๊ณผ ์ฌํ์์ญ ์ฐจ์ด ๋ถ์), ๋ฐ์ดํฐ ์์ง, LLM์ ํ์ฉํ ๋ฐ์ดํฐ ์ฆ๊ฐ, ํ๋กฌํํธ ์คํ |
+| ์๊ฐ์ฐ | EDA(๊ตญ์ด์์ญ๊ณผ ์ฌํ์์ญ ์ฐจ์ด ๋ถ์), ๋ฐ์ดํฐ ์์ง, RAG ๊ตฌํ(Dense Retrieval) |
+| ์ด์์ | ๋ฉ๋ชจ๋ฆฌ/์๋ ์ต์ ํ, ์์ํ(BitsAndBytes, GPTQ), ๋ฐ์ดํฐ ์์ง, ๋ฐ์ดํฐ ์ ์ , RAG ๊ตฌํ(Elastic Search, Reranker, RAFT) |
+| ํ์ฑ๋ฏผ | EDA(๋ฐ์ดํฐ ์ถ์ฒ ๊ธฐ๋ฐ ๋ถ์), LLM์ ํ์ฉํ ๋ฐ์ดํฐ ์ฆ๊ฐ |
+| ํ์ฑ์ฌ | EDA(๊ตญ์ด์์ญ๊ณผ ์ฌํ์์ญ ์ฐจ์ด ๋ถ์), streamlit ์๊ฐํ |
+
+## ๐ Results
+
+
+## ๐ ๏ธ**Dependencies**
+```
+# CUDA Version: 12.2
+# Ubuntu 20.04.6
+# python 3.10.13
+
+# Deep Learning
+auto_gptq==0.7.1
+bitsandbytes==0.44.1
+evaluate==0.4.3
+huggingface-hub==0.26.2
+numpy==2.0.0
+optimum==1.23.3
+peft==0.5.0
+scikit-learn==1.5.2
+torch==2.5.1 # 2.5.1+cu124
+tqdm==4.67.0
+transformers==4.46.2
+trl==0.12.0
+wandb==0.18.5
+
+# RAG
+elasticsearch==8.16.0
+konlpy==0.6.0
+rank-bm25==0.2.2
+wikiextractor==3.0.6
+faiss-cpu==1.9.0 # faiss-gpu==1.7.2
+
+# Utils
+beautifulsoup4==4.12.3
+ipykernel==6.29.5
+ipywidgets==8.1.5
+loguru==0.7.2
+matplotlib==3.9.2
+python-dotenv==1.0.1
+reportlab==4.2.5
+streamlit==1.40.1
+pdfminer.six==20240706
+
+# Google Drive API
+google-api-python-client==2.151.0
+google-auth-httplib2==0.2.0
+google-auth-oauthlib==1.2.1
+
+# Automatically installed dependencies
+# pandas==2.2.3
+# pyarrow==18.0.0
+# datasets==3.1.0
+# safetensors==0.4.5
+# scipy==1.14.1
+# tqdm==4.67.0
+# PyYAML==6.0.2
+# requests==2.32.3
+
+```
+## ๐พ Usage
+1. Setting
+```
+$ pip install -r requirements.txt
+```
+2. train & inference
+```
+$ python3 code/main.py
+```
+
+## ๐ ํ๋ก์ ํธ ๊ตฌ์กฐ
+```
+code
+ โฃ rag
+ โ โฃ data_process
+ โ โ โฃ external_data.py
+ โ โ โ wiki_dump.py
+ โ โฃ README.md
+ โ โฃ __init__.py
+ โ โฃ chunk_data.py
+ โ โฃ dpr_data.py
+ โ โฃ encoder.py
+ โ โฃ index_runner.py
+ โ โฃ indexers.py
+ โ โฃ prepare_dense.py
+ โ โฃ reranker.py
+ โ โฃ retriever.py
+ โ โฃ retriever_bm25.py
+ โ โฃ retriever_elastic.py
+ โ โฃ train.py
+ โ โฃ trainer.py
+ โ โ utils.py
+ โฃ utils
+ โ โฃ __init__.py
+ โ โฃ common.py
+ โ โฃ gdrive_manager.py
+ โ โ hf_manager.py
+ โฃ data_loaders.py
+ โฃ inference.py
+ โฃ labeling.py
+ โฃ main.py
+ โฃ model.py
+ โฃ split.py
+ โ trainer.py
+ data_aug
+ โฃ add_CoT.py
+ โ aug_philo.py
+ data_process
+ โฃ crawling_gichulpass.py
+ โฃ external_musr.py
+ โฃ external_race.py
+ โฃ external_sat_gaokao.py
+ โฃ pdf_to_txt.py
+ โฃ process_balance_choices.py
+ โฃ process_formatting.py
+ โ process_google_translate.py
+ data_viz
+ โฃ csv2pdf.py
+ โฃ labeling.py
+ โ streamlit_app.py
+config
+ โฃ sample
+ โ โฃ config.yaml
+ โ โ env-sample.txt
+ โ elastic_setting.json
+```
diff --git a/assets/final_result.png b/assets/final_result.png
new file mode 100644
index 0000000..153ef8a
Binary files /dev/null and b/assets/final_result.png differ
diff --git a/assets/private_rank.png b/assets/private_rank.png
new file mode 100644
index 0000000..20e1c94
Binary files /dev/null and b/assets/private_rank.png differ
diff --git a/assets/public_rank.png b/assets/public_rank.png
new file mode 100644
index 0000000..5dfb361
Binary files /dev/null and b/assets/public_rank.png differ
diff --git a/code/data_loaders.py b/code/data_loaders.py
new file mode 100644
index 0000000..5ab6565
--- /dev/null
+++ b/code/data_loaders.py
@@ -0,0 +1,385 @@
+from ast import literal_eval
+import os
+import pickle
+from typing import Dict, List
+
+from datasets import Dataset
+from dotenv import load_dotenv
+from loguru import logger
+import numpy as np
+import pandas as pd
+from rag import ElasticsearchRetriever, Reranker
+from rag.dpr_data import KorQuadDataset
+from rag.encoder import KobertBiEncoder
+from rag.indexers import DenseFlatIndexer
+from rag.retriever import KorDPRRetriever, get_passage_file
+from utils import load_config
+
+
+class DataLoader:
+ def __init__(self, tokenizer, data_config):
+ self.tokenizer = tokenizer
+ self.retriever_config = data_config["retriever"]
+ self.train_path = data_config["train_path"]
+ self.test_path = data_config["test_path"]
+ self.processed_train_path = data_config["processed_train_path"]
+ self.processed_test_path = data_config["processed_test_path"]
+ self.max_seq_length = data_config["max_seq_length"]
+ self.test_size = data_config["test_size"]
+ self.prompt_config = data_config["prompt"]
+
+ def prepare_datasets(self, is_train):
+ """ํ์ต ๋๋ ํ
์คํธ์ฉ ๋ฐ์ดํฐ์
์ค๋น"""
+ # prompt ์ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ์
ํ์ผ์ด ์กด์ฌํ๋ค๋ฉด ์ด๋ฅผ ๋ก๋ํฉ๋๋ค.
+ processed_df_path = self.processed_train_path if is_train else self.processed_test_path
+ if os.path.isfile(processed_df_path):
+ logger.info(f"์ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ์
์ ๋ถ๋ฌ์ต๋๋ค: {processed_df_path}")
+ processed_df = pd.read_csv(processed_df_path, encoding="utf-8")
+ processed_df["messages"] = processed_df["messages"].apply(literal_eval)
+ processed_dataset = Dataset.from_pandas(processed_df)
+ else:
+ dataset = self._load_data(is_train)
+ processed_dataset = self._process_dataset(dataset, is_train)
+
+ if is_train:
+ tokenized_dataset = self._tokenize_dataset(processed_dataset)
+ splitted_dataset = self._split_dataset(tokenized_dataset)
+ return splitted_dataset
+ return processed_dataset
+
+ def _retrieve(self, df): # noqa: C901
+ if self.retriever_config["retriever_type"] == "Elasticsearch":
+ retriever = ElasticsearchRetriever(
+ index_name=self.retriever_config["index_name"],
+ )
+ elif self.retriever_config["retriever_type"] == "DPR":
+ # KorDPRRetriever ์ฌ์ฉ
+ try:
+ model = KobertBiEncoder() # ๋ชจ๋ธ ์ด๊ธฐํ
+ model.load("./rag/output/my_model.pt") # ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
+ logger.debug("Model loaded successfully.")
+ assert model is not None, "Model is None after loading."
+ except Exception as e:
+ logger.debug(f"Error while loading model: {e}")
+
+ try:
+ valid_dataset = KorQuadDataset("./rag/data/KorQuAD_v1.0_dev.json") # ๋ฐ์ดํฐ์
์ค๋น
+ logger.debug("Valid dataset loaded successfully.")
+ except Exception as e:
+ logger.debug(f"Error while loading valid dataset: {e}")
+
+ try:
+ index = DenseFlatIndexer() # ์ธ๋ฑ์ค ์ค๋น
+ index.deserialize(path="./rag/2050iter_flat/")
+ logger.debug("Index loaded successfully.")
+ assert index is not None, "Index is None after loading."
+ except Exception as e:
+ logger.debug(f"Error while loading index: {e}")
+
+ ds_retriever = KorDPRRetriever(model=model, valid_dataset=valid_dataset, index=index)
+ logger.debug("KorDPRRetriever initialized successfully.")
+ else:
+ return [""] * len(df)
+
+ def _combine_text(row):
+ # NaN ๊ฐ ์ฒ๋ฆฌ
+ paragraph = "" if pd.isna(row["paragraph"]) else str(row["paragraph"])
+ if pd.isna(row["problems"]):
+ problems = {"question": "", "choices": []}
+ else:
+ problems = row["problems"]
+ question = str(problems.get("question", ""))
+ choices = [str(choice) for choice in problems.get("choices", [])]
+
+ if self.retriever_config["query_type"] == "pqc":
+ return paragraph + " " + question + " " + " ".join(choices)
+ if self.retriever_config["query_type"] == "pq":
+ return paragraph + " " + question
+ if self.retriever_config["query_type"] == "pc":
+ return paragraph + " " + " ".join(choices)
+ else:
+ return paragraph
+
+ top_k = self.retriever_config["top_k"]
+ threshold = self.retriever_config["threshold"]
+ query_max_length = self.retriever_config["query_max_length"]
+
+ queries = df.apply(_combine_text, axis=1)
+ if self.retriever_config["retriever_type"] == "Elasticsearch":
+ filtered_queries = [(i, q) for i, q in enumerate(queries) if len(q) <= query_max_length]
+ if not filtered_queries:
+ return [""] * len(queries)
+
+ indices, valid_queries = zip(*filtered_queries)
+ retrieve_results = retriever.bulk_retrieve(valid_queries, top_k)
+ rerank_k = self.retriever_config["rerank_k"]
+ if rerank_k > 0:
+ with Reranker() as reranker:
+ retrieve_results = reranker.rerank(valid_queries, retrieve_results, rerank_k)
+ # [[{"text":"์๋
ํ์ธ์", "score":0.5}, {"text":"๋ฐ๊ฐ์ต๋๋ค", "score":0.3},],]
+
+ docs = [""] * len(queries)
+ for idx, result in zip(indices, retrieve_results):
+ docs[idx] = " ".join(item["text"] for item in result if item["score"] >= threshold)
+ docs[idx] = docs[idx][: self.retriever_config["result_max_length"]]
+ elif self.retriever_config["retriever_type"] == "DPR": # DPR์ธ ๊ฒฝ์ฐ
+ docs = []
+ for query in queries:
+ passages = ds_retriever.retrieve(query=query, k=top_k) # DPR์ผ๋ก ๊ฒ์
+
+ # passage ๋ก๋ฉ ๋ฐ ๊ฒฐํฉ
+ for idx, (passage, score) in enumerate(passages):
+ # passage ID์ ํด๋นํ๋ ํ์ผ ๊ฒฝ๋ก ๊ฐ์ ธ์ค๊ธฐ
+ path = get_passage_file([idx])
+ if path:
+ with open(path, "rb") as f:
+ passage_dict = pickle.load(f)
+ docs.append((passage_dict[idx], score)) # passage์ score ์ ์ฅ
+ else:
+ logger.debug(f"No passage found for ID: {idx}")
+
+ # ๋ก๊น
์ถ๊ฐ
+ logger.info(f"๊ฐ์ฐ Query: {query}")
+ logger.info(f"Rank {idx+1}: Score: {score:.4f}, Passage: {passage}")
+
+ return docs
+
+ def _load_data(self, is_train) -> List[Dict]:
+ """csv๋ฅผ ์ฝ์ด์ค๊ณ dictionary ๋ฐฐ์ด ํํ๋ก ๋ณํํฉ๋๋ค."""
+ file_path = self.train_path if is_train else self.test_path
+ df = pd.read_csv(file_path)
+ df["problems"] = df["problems"].apply(literal_eval)
+ docs = self._retrieve(df)
+ records = []
+ for idx, row in df.iterrows():
+ problems = row["problems"]
+ record = {
+ "id": row["id"],
+ "paragraph": row["paragraph"],
+ "question": problems["question"],
+ "choices": problems["choices"],
+ "answer": problems.get("answer", None),
+ "question_plus": problems.get("question_plus", None),
+ "document": docs[idx],
+ }
+ records.append(record)
+ logger.info("dataset ๋ก๋ ๋ฐ retrive ์๋ฃ.")
+ return records
+
+ def _process_dataset(self, dataset: List[Dict], is_train=True):
+ """๋ฐ์ดํฐ์ ํ๋กฌํํธ ์ ์ฉ"""
+
+ # ๋ฐ์ดํฐ์
์ prompt ์ ์ฒ๋ฆฌํ๊ณ ์ ์ฅํฉ๋๋ค.
+ logger.info("๋ฐ์ดํฐ์
์ ์ฒ๋ฆฌ๋ฅผ ์ํํฉ๋๋ค.")
+ processed_data = []
+ for row in dataset:
+ choices_string = "\n".join([f"{idx + 1} - {choice}" for idx, choice in enumerate(row["choices"])])
+
+ # start
+ if row["question_plus"]:
+ message_start = self.prompt_config["start_with_plus"].format(
+ paragraph=row["paragraph"],
+ question=row["question"],
+ question_plus=row["question_plus"],
+ choices=choices_string,
+ )
+ else:
+ message_start = self.prompt_config["start"].format(
+ paragraph=row["paragraph"],
+ question=row["question"],
+ choices=choices_string,
+ )
+ # mid
+ if row["document"]:
+ message_mid = self.prompt_config["mid_with_document"].format(
+ document=row["document"],
+ )
+ else:
+ message_mid = self.prompt_config["mid"]
+ # end
+ message_end = self.prompt_config["end"]
+
+ user_message = message_start + message_mid + message_end
+ messages = [
+ {"role": "system", "content": "์ง๋ฌธ์ ์ฝ๊ณ ์ง๋ฌธ์ ๋ต์ ๊ตฌํ์ธ์."},
+ {"role": "user", "content": user_message},
+ ]
+
+ if is_train:
+ messages.append({"role": "assistant", "content": f"{row['answer']}"})
+
+ processed_data.append({"id": row["id"], "messages": messages, "label": row["answer"] if is_train else None})
+
+ processed_df = pd.DataFrame(processed_data)
+ logger.info("๋ฐ์ดํฐ์
์ ์ฒ๋ฆฌ๊ฐ ์๋ฃ๋์์ต๋๋ค.")
+ processed_df_path = self.processed_train_path if is_train else self.processed_test_path
+ if processed_df_path:
+ processed_df.to_csv(processed_df_path, index=False, encoding="utf-8")
+ logger.info("์ ์ฒ๋ฆฌ๋ ๋ฐ์ดํฐ์
์ด ์ ์ฅ๋์์ต๋๋ค.")
+ return Dataset.from_pandas(processed_df)
+
+ def _tokenize_dataset(self, dataset):
+ def formatting_prompts_func(example):
+ output_texts = []
+ for i in range(len(example["messages"])):
+ output_texts.append(
+ self.tokenizer.apply_chat_template(
+ example["messages"][i],
+ tokenize=False,
+ )
+ )
+ return output_texts
+
+ def tokenize(element):
+ outputs = self.tokenizer(
+ formatting_prompts_func(element),
+ truncation=False,
+ padding=False,
+ return_overflowing_tokens=False,
+ return_length=False,
+ )
+ return {
+ "input_ids": outputs["input_ids"],
+ "attention_mask": outputs["attention_mask"],
+ }
+
+ tokenized_dataset = dataset.map(
+ tokenize,
+ remove_columns=list(dataset.features),
+ batched=True,
+ num_proc=4,
+ load_from_cache_file=True,
+ desc="Tokenizing",
+ )
+
+ # ํ ํฐ ๊ธธ์ด๊ฐ max_seq_length๋ฅผ ์ด๊ณผํ๋ ๋ฐ์ดํฐ ํํฐ๋ง
+ logger.info(f"dataset length: {len(tokenized_dataset)}")
+ tokenized_dataset = tokenized_dataset.filter(lambda x: len(x["input_ids"]) <= self.max_seq_length)
+ logger.info(f"filtered dataset length: {len(tokenized_dataset)}")
+
+ return tokenized_dataset
+
+ def _split_dataset(self, dataset):
+ split_dataset = dataset.train_test_split(test_size=self.test_size, seed=42)
+ train_dataset = split_dataset["train"]
+ eval_dataset = split_dataset["test"]
+
+ logger.debug(self.tokenizer.decode(train_dataset[0]["input_ids"], skip_special_tokens=True))
+ train_dataset_token_lengths = [len(train_dataset[i]["input_ids"]) for i in range(len(train_dataset))]
+ logger.info(f"max token length: {max(train_dataset_token_lengths)}")
+ logger.info(f"min token length: {min(train_dataset_token_lengths)}")
+ logger.info(f"avg token length: {np.mean(train_dataset_token_lengths)}")
+
+ return train_dataset, eval_dataset
+
+
+if __name__ == "__main__":
+ config_folder = os.path.join(os.path.dirname(__file__), "..", "config/")
+ load_dotenv(os.path.join(config_folder, ".env"))
+ config = load_config()
+ data_config = config["data"]
+
+ def _retrieve(retriever_config, df): # noqa: C901
+ if retriever_config["retriever_type"] == "Elasticsearch":
+ retriever = ElasticsearchRetriever(
+ index_name=retriever_config["index_name"],
+ )
+ elif retriever_config["retriever_type"] == "BM25":
+ raise NotImplementedError("BM25๋ ๋ ์ด์ ์ง์ํ์ง ์์ต๋๋ค. Elasticsearch๋ฅผ ์ฌ์ฉํด์ฃผ์ธ์...")
+
+ elif retriever_config["retriever_type"] == "DPR":
+ # KorDPRRetriever ์ฌ์ฉ
+ try:
+ model = KobertBiEncoder() # ๋ชจ๋ธ ์ด๊ธฐํ
+ model.load("./rag/output/my_model.pt") # ๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
+ logger.debug("Model loaded successfully.")
+ assert model is not None, "Model is None after loading."
+ except Exception as e:
+ logger.debug(f"Error while loading model: {e}")
+
+ try:
+ valid_dataset = KorQuadDataset("./rag/data/KorQuAD_v1.0_dev.json") # ๋ฐ์ดํฐ์
์ค๋น
+ logger.debug("Valid dataset loaded successfully.")
+ except Exception as e:
+ logger.debug(f"Error while loading valid dataset: {e}")
+
+ try:
+ index = DenseFlatIndexer() # ์ธ๋ฑ์ค ์ค๋น
+ index.deserialize(path="./rag/2050iter_flat/")
+ logger.debug("Index loaded successfully.")
+ assert index is not None, "Index is None after loading."
+ except Exception as e:
+ logger.debug(f"Error while loading index: {e}")
+
+ ds_retriever = KorDPRRetriever(model=model, valid_dataset=valid_dataset, index=index)
+ logger.debug("KorDPRRetriever initialized successfully.")
+
+ else:
+ return [""] * len(df)
+
+ def _combine_text(row):
+ if retriever_config["query_type"] == "pqc":
+ return row["paragraph"] + " " + row["problems"]["question"] + " " + " ".join(row["problems"]["choices"])
+ if retriever_config["query_type"] == "pq":
+ return row["paragraph"] + " " + row["problems"]["question"]
+ if retriever_config["query_type"] == "pc":
+ return row["paragraph"] + " " + " ".join(row["problems"]["choices"])
+ else:
+ return row["paragraph"]
+
+ top_k = retriever_config["top_k"]
+ threshold = retriever_config["threshold"]
+ query_max_length = retriever_config["query_max_length"]
+
+ queries = df.apply(_combine_text, axis=1)
+ if retriever_config["retriever_type"] == "Elasticsearch":
+ filtered_queries = [(i, q) for i, q in enumerate(queries) if len(q) <= query_max_length]
+ if not filtered_queries:
+ return [""] * len(queries)
+
+ indices, valid_queries = zip(*filtered_queries)
+ retrieve_results = retriever.bulk_retrieve(valid_queries, top_k)
+ rerank_k = retriever_config["rerank_k"]
+ if rerank_k > 0:
+ with Reranker() as reranker:
+ retrieve_results = reranker.rerank(valid_queries, retrieve_results, rerank_k)
+ # [[{"text":"์๋
ํ์ธ์", "score":0.5}, {"text":"๋ฐ๊ฐ์ต๋๋ค", "score":0.3},],]
+
+ docs = [""] * len(queries)
+ for idx, result in zip(indices, retrieve_results):
+ docs[idx] = " ".join(
+ f"[{item['score']}]: {item['text']}" for item in result if item["score"] >= threshold
+ )
+
+ elif retriever_config["retriever_type"] == "DPR": # DPR์ธ ๊ฒฝ์ฐ
+ docs = []
+ for query in queries:
+ passages = ds_retriever.retrieve(query=query, k=top_k) # DPR์ผ๋ก ๊ฒ์
+
+ # passage ๋ก๋ฉ ๋ฐ ๊ฒฐํฉ
+ for idx, (passage, score) in enumerate(passages):
+ # passage ID์ ํด๋นํ๋ ํ์ผ ๊ฒฝ๋ก ๊ฐ์ ธ์ค๊ธฐ
+ path = get_passage_file([idx])
+ if path:
+ with open(path, "rb") as f:
+ passage_dict = pickle.load(f)
+ docs.append((passage_dict[idx], score)) # passage์ score ์ ์ฅ
+ else:
+ logger.debug(f"No passage found for ID: {idx}")
+
+ # ๋ก๊น
์ถ๊ฐ
+ logger.info(f"Query: {query}")
+ logger.info(f"Rank {idx+1}: Score: {score:.4f}, Passage: {passage}")
+ return docs
+
+ def load_and_save(retriever_config, file_path) -> List[Dict]:
+ """csv๋ฅผ ์ฝ์ด์ค๊ณ dictionary ๋ฐฐ์ด ํํ๋ก ๋ณํํฉ๋๋ค."""
+ df = pd.read_csv(file_path)
+ df["problems"] = df["problems"].apply(literal_eval)
+ docs = _retrieve(retriever_config, df)
+ df["documents"] = docs
+ df.to_csv(file_path.replace(".csv", "_retrieve.csv"), index=False)
+ logger.debug("retrieve ๊ฒฐ๊ณผ๊ฐ csv๋ก ์ ์ฅ๋์์ต๋๋ค.")
+
+ load_and_save(data_config["retriever"], data_config["train_path"])
+ load_and_save(data_config["retriever"], data_config["test_path"])
diff --git a/code/inference.py b/code/inference.py
new file mode 100644
index 0000000..58767f5
--- /dev/null
+++ b/code/inference.py
@@ -0,0 +1,48 @@
+from loguru import logger
+import numpy as np
+import pandas as pd
+import torch
+from tqdm import tqdm
+
+
+class InferenceModel:
+ def __init__(self, inference_config, model, tokenizer, test_dataset):
+ self.inference_config = inference_config
+ self.model = model
+ self.tokenizer = tokenizer
+ self.test_dataset = test_dataset
+ self.pred_choices_map = {0: "1", 1: "2", 2: "3", 3: "4", 4: "5"}
+
+ def run_inference(self):
+ if not self.inference_config["do_test"]:
+ logger.info("์ถ๋ก ๋จ๊ณ๋ฅผ ์๋ตํฉ๋๋ค. inference do_test ์ค์ ์ ํ์ธํ์ธ์.")
+ return
+
+ results = self._inference(self.test_dataset)
+ return self._save_results(results)
+
+ def _inference(self, test_dataset):
+ infer_results = []
+ self.model.config.use_cache = True
+ self.model.eval()
+
+ with torch.inference_mode():
+ for example in tqdm(test_dataset):
+ outputs = self.model(
+ self.tokenizer.apply_chat_template(
+ example["messages"], tokenize=True, add_generation_prompt=True, return_tensors="pt"
+ ).to("cuda")
+ )
+
+ logits = outputs.logits[:, -1].flatten().cpu()
+ target_logits = [logits[self.tokenizer.vocab[str(i + 1)]] for i in range(5)] # ์ ํ์ง๋ ํญ์ 5๊ฐ
+ probs = torch.nn.functional.softmax(torch.tensor(target_logits, dtype=torch.float32), dim=-1)
+ predict_value = self.pred_choices_map[np.argmax(probs.detach().cpu().numpy())]
+
+ infer_results.append({"id": example["id"], "answer": predict_value})
+
+ return infer_results
+
+ def _save_results(self, results):
+ logger.info(self.inference_config["output_path"])
+ pd.DataFrame(results).to_csv(self.inference_config["output_path"], index=False)
diff --git a/code/labeling.py b/code/labeling.py
new file mode 100644
index 0000000..e9956c7
--- /dev/null
+++ b/code/labeling.py
@@ -0,0 +1,119 @@
+from cleanlab.classification import CleanLearning
+from cleanlab.filter import find_label_issues
+from loguru import logger
+import numpy as np
+import pandas as pd
+from sklearn.cluster import KMeans
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.model_selection import StratifiedKFold
+from xgboost import XGBClassifier
+
+
+# Pandas ์ถ๋ ฅ ์ค์
+pd.set_option("display.max_columns", None)
+pd.set_option("display.max_rows", None)
+pd.set_option("display.max_colwidth", None)
+
+
+def create_initial_labels(input_file, output_file, num_clusters=2):
+ """TF-IDF์ K-means๋ฅผ ์ฌ์ฉํ์ฌ ์ด๊ธฐ ๋ผ๋ฒจ์ ์์ฑํฉ๋๋ค."""
+ df = pd.read_csv(input_file)
+ df.dropna(subset=["paragraph", "problems"], inplace=True)
+ df["combined_text"] = df["paragraph"] + " " + df["problems"]
+
+ vectorizer = TfidfVectorizer()
+ X = vectorizer.fit_transform(df["combined_text"])
+
+ kmeans = KMeans(n_clusters=num_clusters, random_state=42)
+ kmeans.fit(X)
+
+ df["target"] = kmeans.labels_
+ final_columns = ["id", "paragraph", "problems", "question_plus", "target"]
+ df[final_columns].to_csv(output_file, index=False)
+ logger.info(f"์ด๊ธฐ ๋ผ๋ฒจ๋ง์ด ์๋ฃ๋์์ต๋๋ค. ๊ฒฐ๊ณผ๊ฐ {output_file}์ ์ ์ฅ๋์์ต๋๋ค.")
+
+
+def load_and_preprocess_data(file_path):
+ """๋ฐ์ดํฐ๋ฅผ ๋ก๋ํ๊ณ ์ ์ฒ๋ฆฌํฉ๋๋ค."""
+ df = pd.read_csv(file_path)
+ if "target" not in df.columns:
+ raise ValueError("๋ฐ์ดํฐ์
์ 'target' ์ด์ด ์์ต๋๋ค.")
+
+ X = df[["paragraph", "problems"]].astype(str).agg(" ".join, axis=1)
+ y = df["target"].astype(int)
+ return df, X, y
+
+
+def vectorize_text(X, max_features=5000):
+ """ํ
์คํธ ๋ฐ์ดํฐ๋ฅผ ๋ฒกํฐํํฉ๋๋ค."""
+ vectorizer = TfidfVectorizer(max_features=max_features)
+ return vectorizer.fit_transform(X)
+
+
+def train_and_predict(X_vectorized, y, n_splits=5):
+ """๋ชจ๋ธ์ ํ๋ จํ๊ณ ์์ธก ํ๋ฅ ์ ๋ฐํํฉ๋๋ค."""
+ base_model = XGBClassifier(eval_metric="mlogloss", n_estimators=100)
+ model = CleanLearning(base_model)
+ skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
+ pred_probs = np.zeros((len(y), len(np.unique(y))))
+
+ for fold, (train_index, val_index) in enumerate(skf.split(X_vectorized, y), 1):
+ logger.info(f"Fold {fold}/{n_splits}")
+ X_train, X_val = X_vectorized[train_index], X_vectorized[val_index]
+ y_train, _ = y[train_index], y[val_index]
+ model.fit(X_train, y_train)
+ pred_probs[val_index] = model.predict_proba(X_val)
+
+ return pred_probs
+
+
+def find_and_update_label_issues(df, y, pred_probs):
+ """๋ ์ด๋ธ ์ด์๋ฅผ ์ฐพ๊ณ ๋ฐ์ดํฐํ๋ ์์ ์
๋ฐ์ดํธํฉ๋๋ค."""
+ label_issues = find_label_issues(labels=y, pred_probs=pred_probs, return_indices_ranked_by="self_confidence")
+ df["is_label_issue"] = False
+ df.loc[label_issues, "is_label_issue"] = True
+ df["suggested_label"] = np.argmax(pred_probs, axis=1)
+ return df
+
+
+def save_and_print_results(df, output_file):
+ """๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํ๊ณ ์ถ๋ ฅํฉ๋๋ค."""
+ final_columns = ["id", "paragraph", "problems", "question_plus", "target", "suggested_label", "is_label_issue"]
+ df[final_columns].to_csv(output_file, index=False)
+
+ logger.info("\nID์ ์ ์๋ ๋ ์ด๋ธ:")
+ logger.info(df[["id", "suggested_label"]].to_string(index=False))
+
+ logger.info("\n๋ ์ด๋ธ ์ด์ ํต๊ณ:")
+ logger.info(df["is_label_issue"].value_counts(normalize=True))
+
+ logger.info("\n์๋ ๋ ์ด๋ธ๊ณผ ์ ์๋ ๋ ์ด๋ธ ๋น๊ต:")
+ logger.info(pd.crosstab(df["target"], df["suggested_label"]))
+
+
+def main():
+ initial_input_file = "../data/train.csv"
+ initial_output_file = "../data/output_with_labels.csv"
+ final_output_file = "../data/cleaned_output_with_labels_CL.csv"
+
+ # ์ด๊ธฐ ๋ผ๋ฒจ๋ง ์ํ
+ create_initial_labels(initial_input_file, initial_output_file)
+
+ # ๋ฐ์ดํฐ ๋ก๋ ๋ฐ ์ ์ฒ๋ฆฌ
+ df, X, y = load_and_preprocess_data(initial_output_file)
+
+ # ํ
์คํธ ๋ฒกํฐํ
+ X_vectorized = vectorize_text(X)
+
+ # ๋ชจ๋ธ ํ๋ จ ๋ฐ ์์ธก
+ pred_probs = train_and_predict(X_vectorized, y)
+
+ # ๋ ์ด๋ธ ์ด์ ์ฐพ๊ธฐ ๋ฐ ์
๋ฐ์ดํธ
+ df = find_and_update_label_issues(df, y, pred_probs)
+
+ # ๊ฒฐ๊ณผ ์ ์ฅ ๋ฐ ์ถ๋ ฅ
+ save_and_print_results(df, final_output_file)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/code/main.py b/code/main.py
new file mode 100644
index 0000000..a8733ec
--- /dev/null
+++ b/code/main.py
@@ -0,0 +1,84 @@
+import os
+
+from data_loaders import DataLoader
+from inference import InferenceModel
+from loguru import logger
+from model import ModelHandler
+from trainer import CustomTrainer
+from utils import (
+ GoogleDriveManager,
+ create_experiment_filename,
+ load_config,
+ load_env_file,
+ log_config,
+ set_logger,
+ set_seed,
+)
+import wandb
+
+
+def main():
+    """Entry point: configure, train, run inference, and upload artifacts.
+
+    Loads env/config, seeds RNGs, starts a wandb run, trains the model and
+    runs inference; on success uploads the output to Google Drive, on any
+    exception logs it and finishes the wandb run with exit_code=1.
+    """
+    # Environment, config, logging, and seed setup.
+    load_env_file()
+    config = load_config()
+    set_logger(log_file=config["log"]["file"], log_level=config["log"]["level"])
+    set_seed()
+
+    # wandb setup.
+    exp_name = create_experiment_filename(config)
+    wandb.init(
+        config=config,
+        project=config["wandb"]["project"],
+        entity=config["wandb"]["entity"],
+        name=exp_name,
+    )
+
+    # Propagate the experiment name into the run config and the output path.
+    config["training"]["run_name"] = exp_name
+    config["inference"]["output_path"] = os.path.join(config["inference"]["output_path"], exp_name + "_output.csv")
+    log_config(config)
+
+    try:
+        # Model and tokenizer setup.
+        model_handler = ModelHandler(config["model"])
+        model, tokenizer = model_handler.setup()
+
+        # Dataset preparation (train/eval and test splits).
+        data_processor = DataLoader(tokenizer, config["data"])
+        train_dataset, eval_dataset = data_processor.prepare_datasets(is_train=True)
+        test_dataset = data_processor.prepare_datasets(is_train=False)
+
+        # Training.
+        trainer = CustomTrainer(
+            training_config=config["training"],
+            model=model,
+            tokenizer=tokenizer,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+        )
+        trained_model = trainer.train()
+
+        # Inference.
+        inferencer = InferenceModel(
+            inference_config=config["inference"],
+            model=trained_model,
+            tokenizer=tokenizer,
+            test_dataset=test_dataset,
+        )
+        inferencer.run_inference()
+
+    except Exception as e:
+        logger.exception(f"Error occurred: {e}")
+        wandb.finish(exit_code=1)
+    else:
+        # Only upload when the whole pipeline succeeded.
+        logger.info("Upload output & config to GDrive...")
+        gdrive_manager = GoogleDriveManager()
+        gdrive_manager.upload_exp(
+            config["exp"]["username"],
+            config["inference"]["output_path"],
+        )
+        wandb.finish()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/code/model.py b/code/model.py
new file mode 100644
index 0000000..6d7fd69
--- /dev/null
+++ b/code/model.py
@@ -0,0 +1,60 @@
+from loguru import logger
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+
+
+class ModelHandler:
+    """Loads the causal-LM base model and its tokenizer from one config dict.
+
+    Expects ``model_config`` to contain "base_model" plus nested "model" and
+    "tokenizer" sections.
+    """
+
+    def __init__(self, model_config):
+        self.base_model = model_config["base_model"]
+        self.model_config = model_config["model"]
+        self.tokenizer_config = model_config["tokenizer"]
+
+    def setup(self):
+        """Return (model, tokenizer), both fully configured."""
+        model = self._load_model()
+        tokenizer = self._load_tokenizer()
+        return model, tokenizer
+
+    def _load_model(self):
+        """Build the AutoModelForCausalLM with optional quantization.
+
+        quantization == "BitsAndBytes": 8-bit or 4-bit (nf4) loading;
+        quantization == "auto": let transformers pick dtype and device_map;
+        otherwise: load with the configured torch dtype.
+        """
+        torch_dtype = getattr(torch, self.model_config["torch_dtype"])
+        base_kwargs = {"trust_remote_code": True, "low_cpu_mem_usage": self.model_config["low_cpu_mem_usage"]}
+
+        if self.model_config["quantization"] == "BitsAndBytes":
+            bits = self.model_config["bits"]
+            if bits == 8:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_8bit=True,
+                    bnb_8bit_use_double_quant=self.model_config["use_double_quant"],
+                    bnb_8bit_compute_dtype=torch_dtype,
+                )
+            elif bits == 4:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_use_double_quant=self.model_config["use_double_quant"],
+                    bnb_4bit_compute_dtype=torch_dtype,
+                )
+            else:
+                # Only 4-bit and 8-bit BitsAndBytes loading is supported.
+                raise ValueError(f"Unsupported bits value: {bits}")
+
+            base_kwargs["quantization_config"] = quantization_config
+        elif self.model_config["quantization"] == "auto":
+            base_kwargs["torch_dtype"] = "auto"
+            base_kwargs["device_map"] = "auto"
+        else:
+            base_kwargs["torch_dtype"] = torch_dtype
+
+        logger.debug(f"base_kwargs: {base_kwargs}")
+        model = AutoModelForCausalLM.from_pretrained(self.base_model, **base_kwargs)
+        # Honor the configured use_cache flag on the loaded model.
+        model.config.use_cache = self.model_config["use_cache"]
+        return model
+
+    def _load_tokenizer(self):
+        """Load and configure the tokenizer for the base model."""
+        tokenizer = AutoTokenizer.from_pretrained(self.base_model, trust_remote_code=True)
+        self._setup_tokenizer(tokenizer)
+        return tokenizer
+
+    def _setup_tokenizer(self, tokenizer):
+        """Apply the chat template, pad token (= EOS), and padding side from config."""
+        tokenizer.chat_template = self.tokenizer_config["chat_template"]
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        tokenizer.padding_side = self.tokenizer_config["padding_side"]
diff --git a/code/rag/README.md b/code/rag/README.md
new file mode 100644
index 0000000..4a54cc8
--- /dev/null
+++ b/code/rag/README.md
@@ -0,0 +1,34 @@
+# Dense Retriever ์ฌ์ฉ ๊ฐ์ด๋
+
+์ด ๊ฐ์ด๋๋ Dense Retriever๋ฅผ ์ค์ ํ๊ณ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช
ํฉ๋๋ค.
+
+## ์ค๋น ๋จ๊ณ
+
+1. **์ํคํผ๋์ ๋คํ ํ์ผ ์ค๋น**
+ - `rag` ํด๋ ๋ด์ ์ํคํผ๋์ ๋คํ ํ์ผ์ ๋ค์ด๋ก๋ํ๊ณ ์์ถ์ ํด์ ํฉ๋๋ค.
+ - `text` ํด๋ ๋ด์ `AA`, `AB`, `AC` ํด๋๊ฐ ์กด์ฌํด์ผ ํฉ๋๋ค.
+ - ๊ฐ ํด๋ ์์ `wiki_`๋ก ์์ํ๋ ํ์ผ๋ค์ด ์์ด์ผ ํฉ๋๋ค.
+2. **KorQuAD_v1.0 ๋ฐ์ดํฐ์
์ค๋น**
+ - `data` ํด๋ ๋ด์ KorQuAD_v1.0_dev, KorQuAD_v1.0_train ํ์ผ์ ์ค๋นํด์ผํฉ๋๋ค.
+ - https://korquad.github.io/category/1.0_KOR.html
+3. **๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ**
+ - `prepare_dense.py` ์คํฌ๋ฆฝํธ๋ฅผ ์คํํฉ๋๋ค.
+ - ์คํ ํ ๋ค์ ํ์ผ๋ค์ด ์์ฑ๋์ด์ผ ํฉ๋๋ค:
+     - `processed_passages/0-XXXX.p,XXXX-XXXX.p...`
+ - `titled_passage_map.p`
+ - `2050iter_flat/index_meta.dpr,index.dpr`
+ - ์์ฑ๋ ํ์ผ์ ์ฌ์ฉํ์ฌ ์์ ์ฟผ๋ฆฌ์ ๋ํ ์ ์ ํ ๋ฌธ์ ๊ฒ์ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํฉ๋๋ค.
+## ์ฌ์ฉ ๋ฐฉ๋ฒ
+
+1. **์ค์ ํ์ผ ์์ **
+ - `config` ํ์ผ์์ `retriever_type`์ `"DPR"`๋ก ์ค์ ํฉ๋๋ค.
+
+2. **์คํ**
+ - ๋ค์ ๋ช
๋ น์ด๋ฅผ ์คํํ์ฌ Dense Retriever๋ฅผ ์ฌ์ฉํฉ๋๋ค:
+ ```bash
+ python main.py
+ ```
+
+## ์ฃผ์์ฌํญ
+- ์ด ์ฝ๋๋ https://github.com/TmaxEdu/KorDPR๋ฅผ ์ฐธ๊ณ ํ์ฌ ์์ฑ๋์์ต๋๋ค.
+- ์์ ๋จ๊ณ๋ค์ ์์๋๋ก ์งํํด์ผ Dense Retriever๊ฐ ์ ์์ ์ผ๋ก ์๋ํฉ๋๋ค.
diff --git a/code/rag/__init__.py b/code/rag/__init__.py
new file mode 100644
index 0000000..ca0ab1e
--- /dev/null
+++ b/code/rag/__init__.py
@@ -0,0 +1,18 @@
+# # __init__.py ํ์ผ ๋ด์์ ํ์ํ ๋ชจ๋๋ค์ ์ํฌํธ
+from .chunk_data import DataChunk, save_orig_passage, save_title_index_map
+from .dpr_data import KorQuadDataset, KorQuadSampler, korquad_collator
+from .encoder import KobertBiEncoder
+
+# from .retriever_dense import DenseRetriever # ์ด ๋ถ๋ถ๋ ์ถ๊ฐํฉ๋๋ค.
+from .reranker import Reranker
+
+# #from .utils import get_wiki_filepath, wiki_worker_init # ๋ณ๊ฒฝ ์์
+# # import transformers
+# # # ์ธ๋ถ ์คํฌ๋ฆฝํธ์์ IndexRunner ์ํฌํธ
+# from .index_runner import IndexRunner
+# from .retriever import KorDPRRetriever # retriever.py์์ ๊ฐ์ ธ์ค๊ธฐ
+# from .indexers import DenseFlatIndexer
+# # # ์ถ๊ฐ๋ ๋ถ๋ถ
+from .retriever_bm25 import BM25Retriever
+from .retriever_elastic import ElasticsearchRetriever
+from .trainer import Trainer
diff --git a/code/rag/chunk_data.py b/code/rag/chunk_data.py
new file mode 100644
index 0000000..3bfe86f
--- /dev/null
+++ b/code/rag/chunk_data.py
@@ -0,0 +1,111 @@
+from collections import defaultdict
+from glob import glob
+import logging
+import os
+import pickle
+
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+
+# Ensure the log directory exists, then route all module logging to
+# logs/log.log at DEBUG level with caller metadata in every record.
+os.makedirs("logs", exist_ok=True)
+logging.basicConfig(
+    filename="logs/log.log",
+    level=logging.DEBUG,
+    format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s",
+)
+# Root logger shared by everything in this module.
+logger = logging.getLogger()
+
+
+class DataChunk:
+ """์ธํ text๋ฅผ tokenizingํ ๋ค์ ์ฃผ์ด์ง ๊ธธ์ด๋ก chunking ํด์ ๋ฐํํฉ๋๋ค.
+ ์ด๋ ํ๋์ chunk(context, index ๋จ์)๋ ํ๋์ article์๋ง ์ํด์์ด์ผ ํฉ๋๋ค."""
+
+ def __init__(self, chunk_size=100):
+ self.chunk_size = chunk_size
+ self.tokenizer = AutoTokenizer.from_pretrained("monologg/kobert", trust_remote_code=True)
+
+ def chunk(self, input_file):
+ logger.info(f"Processing file: {input_file}")
+ with open(input_file, "rt", encoding="utf8") as f:
+ input_txt = f.read().strip()
+ input_txt = input_txt.split("")
+ chunk_list = []
+ orig_text = []
+ for art in input_txt:
+ art = art.strip()
+ if not art:
+ logger.debug("Article is empty, passing")
+ continue
+ title = art.split("\n")[0].strip(">").split("title=")[1].strip('"')
+ text = "\n".join(art.split("\n")[2:]).strip()
+
+ logger.debug(f"Processing article: {title}")
+
+ encoded_title = self.tokenizer.encode(title, add_special_tokens=True)
+ encoded_txt = self.tokenizer.encode(text, add_special_tokens=True)
+ if len(encoded_txt) < 5:
+ logger.debug(f"Title {title} has <5 subwords in its article, passing")
+ continue
+
+ for start_idx in range(0, len(encoded_txt), self.chunk_size):
+ end_idx = min(len(encoded_txt), start_idx + self.chunk_size)
+ chunk = encoded_title + encoded_txt[start_idx:end_idx]
+ orig_text.append(self.tokenizer.decode(chunk))
+ chunk_list.append(chunk)
+
+ logger.info(f"Processed {len(orig_text)} chunks from {input_file}.")
+ return orig_text, chunk_list
+
+
+def save_orig_passage(input_path="text", passage_path="processed_passages", chunk_size=100):
+    """Chunk every wiki_* extract under input_path and pickle the chunks.
+
+    Each output file maps a global passage id to its decoded chunk text and
+    is named "<first_id>-<last_id>.p" so ids can later be located by range.
+    """
+    os.makedirs(passage_path, exist_ok=True)
+    app = DataChunk(chunk_size=chunk_size)
+    idx = 0  # running global passage id
+    for path in tqdm(glob(f"{input_path}/*/wiki_*")):
+        ret, _ = app.chunk(path)
+        logger.info(f"Processed {len(ret)} chunks from {path}.")
+        if len(ret) > 0:  # only save when the file produced chunks
+            to_save = {idx + i: ret[i] for i in range(len(ret))}
+            with open(f"{passage_path}/{idx}-{idx+len(ret)-1}.p", "wb") as f:
+                pickle.dump(to_save, f)
+            idx += len(ret)
+
+
+def save_title_index_map(index_path="title_passage_map.p", source_passage_path="processed_passages"):
+    """Build and pickle a mapping from article title to its passage ids.
+
+    Scans every pickled passage file, parses the title out of each decoded
+    passage ("[CLS] <title> [SEP] <body> ..."), and saves the aggregated
+    title -> [passage ids] defaultdict at index_path.
+    """
+    logging.getLogger()  # NOTE(review): return value unused; likely a leftover.
+    logger.debug(f"Looking for files in {source_passage_path}")
+    files = glob(f"{source_passage_path}/*")
+    logger.debug(f"Found {len(files)} files")
+
+    title_id_map = defaultdict(list)
+    for f in tqdm(files):
+        logger.debug(f"Processing file: {f}")
+        with open(f, "rb") as _f:
+            id_passage_map = pickle.load(_f)
+
+        # Inspect the shape and content of the loaded map.
+        logger.debug(f"Loaded {len(id_passage_map)} passages from {f}")
+        logger.debug(f"Sample passage: {list(id_passage_map.items())[:5]}")  # first 5 entries
+
+        for id, passage in id_passage_map.items():
+            # The title sits between [CLS] and the first [SEP] of the decoded chunk.
+            parts = passage.split("[SEP]")
+            if len(parts) > 1:
+                title = parts[0].split("[CLS]")[1].strip()
+                title_id_map[title].append(id)
+            else:
+                logger.debug(f"Unexpected passage format in file {f}, id {id}")
+
+        logger.debug(f"Processed {len(id_passage_map)} passages from {f}...")
+
+    logger.debug(f"Total unique titles: {len(title_id_map)}")
+
+    with open(index_path, "wb") as f:
+        pickle.dump(title_id_map, f)
+
+    logger.debug(f"Finished saving title_index_mapping at {index_path}!")
+
+
+# if __name__ == "__main__":
+# save_orig_passage()
+# save_title_index_map()
diff --git a/code/rag/data_process/external_data.py b/code/rag/data_process/external_data.py
new file mode 100644
index 0000000..e393476
--- /dev/null
+++ b/code/rag/data_process/external_data.py
@@ -0,0 +1,198 @@
+import json
+import os
+from pathlib import Path
+import re
+import urllib.request
+
+from loguru import logger
+
+
+def preprocess_text(text):
+    """Collapse whitespace and strip characters outside the allowed set.
+
+    Keeps Korean, digits, common punctuation, and whitespace; afterwards
+    repeatedly removes parentheses whose contents became empty.
+    """
+    # Normalize newlines (real and escaped) and hash marks to spaces.
+    text = re.sub(r"\n", " ", text)
+    text = re.sub(r"\\n", " ", text)
+    text = re.sub(r"#", " ", text)
+    text = re.sub(r"\s+", " ", text).strip()
+    # NOTE(review): the character class below looks garbled in this patch
+    # (the jamo range ending after "ใฑ-" appears corrupted) -- verify the
+    # intended range (likely the full jamo block) before relying on it.
+    text = re.sub(r'[^ใฑ-ใ๊ฐ-ํฃ0-9!"#%&\'(),-./:;<=>?@[\]^_`{|}~\s]', "", text)
+
+    # Repeatedly remove parentheses whose contents are now empty.
+    pattern = r"\(\s*\)"
+    while re.search(pattern, text):
+        text = re.sub(pattern, "", text)
+
+    return text
+
+
+def process_json_array(json_data):
+ # text ํ๋ ์ ์ฒ๋ฆฌ
+ if "text" in json_data:
+ json_data["text"] = preprocess_text(json_data["text"])
+
+ # title ํ๋ ์ ์ฒ๋ฆฌ
+ if "title" in json_data:
+ json_data["title"] = preprocess_text(json_data["title"])
+
+ return json_data
+
+
+def process_json_file(json_filename):
+    """Preprocess every document in a JSON array file and save a cleaned copy.
+
+    Reads json_filename (a list of docs), cleans each doc's text/title via
+    process_json_array, and writes the result beside the input with a
+    "processed_" filename prefix.
+    """
+    with open(json_filename, "r", encoding="utf-8") as f:
+        docs = json.load(f)
+
+    processed_docs = [process_json_array(item) for item in docs]
+
+    # Split directory and filename, then prefix the filename with "processed_".
+    directory, filename = os.path.split(json_filename)
+    output_path = os.path.join(directory, "processed_" + filename)
+    with open(output_path, "w", encoding="utf-8") as f:
+        json.dump(processed_docs, f, ensure_ascii=False, indent=2)
+
+
+def _dump_wiki(data_path: str = "../data"):
+    """Download the Korean Wikipedia dump, extract it, and concatenate the text.
+
+    Skips the download/extraction steps when their outputs already exist,
+    then merges every extracted wiki_NN file into <data_path>/wiki_dump.txt.
+    """
+    dump_filename = "kowiki-latest-pages-articles.xml.bz2"
+    dump_path = os.path.join(data_path, dump_filename)
+    wiki_url = f"https://dumps.wikimedia.org/kowiki/latest/{dump_filename}"
+
+    # Equivalent of: wget https://dumps.wikimedia.org/kowiki/latest/<dump>
+    if not os.path.exists(dump_path):
+        logger.debug(f"์ํคํผ๋์ ๋คํ๋ฅผ ๋ค์ด๋ก๋ํฉ๋๋ค: {wiki_url}")
+        urllib.request.urlretrieve(wiki_url, dump_path)
+        logger.debug(f"๋ค์ด๋ก๋ ์๋ฃ: {dump_path}")
+
+    # Equivalent of: python -m wikiextractor.WikiExtractor <dump>
+    extract_dir = os.path.join(data_path, "text")
+    if not os.path.exists(extract_dir):
+        logger.debug("WikiExtractor๋ก ๋คํ ํ์ผ์ ์ถ์ถํฉ๋๋ค...")
+        os.system(f"python -m wikiextractor.WikiExtractor {dump_path} -o {extract_dir}")
+        logger.debug("์ถ์ถ ์๋ฃ")
+
+    def _get_filename_list(dirname):
+        # Collect files named wiki_NN produced by WikiExtractor.
+        filepaths = []
+        for root, dirs, files in os.walk(dirname):
+            for file in files:
+                filepath = os.path.join(root, file)
+                if re.match(r"wiki_[0-9][0-9]", file):
+                    filepaths.append(filepath)
+        return sorted(filepaths)
+
+    filepaths = _get_filename_list(extract_dir)
+    output_path = os.path.join(data_path, "wiki_dump.txt")
+
+    # Concatenate the contents of every extracted file.
+    all_text = ""
+    for filepath in filepaths:
+        with open(filepath, "r", encoding="utf-8") as f:
+            all_text += f.read() + "\n"
+
+    # Write the combined text to a single file.
+    with open(output_path, "w", encoding="utf-8") as f:
+        f.write(all_text)
+
+    logger.debug(f"์ด {len(filepaths)}๊ฐ์ ํ์ผ์ ์ฒ๋ฆฌํ์ต๋๋ค.")
+    logger.debug(f"๋ชจ๋ ๋ด์ฉ์ด {output_path} ํ์ผ์ ์ ์ฅ๋์์ต๋๋ค.")
+
+
+def _parse_wiki_dump(file_path: str = "../data/wiki_dump.txt"):
+ """
+ ์ํคํผ๋์ ๋คํ ํ์ผ์ JSON ํ์์ผ๋ก ๋ณํํ๋ ํจ์
+ """
+ documents = []
+ current_doc = ""
+ doc_id = None
+ title = None
+
+ with open(file_path, "r", encoding="utf-8") as f:
+ for line in f:
+ # ์๋ก์ด ๋ฌธ์ ์์
+ if line.startswith(""):
+ if current_doc.strip(): # ๋น ๋ฌธ์๊ฐ ์๋ ๊ฒฝ์ฐ๋ง ์ถ๊ฐ
+ documents.append({"id": doc_id, "title": title, "text": current_doc.strip()})
+ # ๋ฌธ์ ๋ด์ฉ
+ else:
+ current_doc += line
+
+ # JSON ํ์ผ๋ก ์ ์ฅ
+ output_path = file_path.replace(".txt", ".json")
+ with open(output_path, "w", encoding="utf-8") as json_file:
+ json.dump(documents, json_file, ensure_ascii=False, indent=4)
+
+ logger.debug(f"JSON ํ์ผ์ด ์์ฑ๋์์ต๋๋ค: {output_path}")
+ logger.debug(f"์ด {len(documents)}๊ฐ์ ๋ฌธ์๊ฐ ์ฒ๋ฆฌ๋์์ต๋๋ค.")
+ return documents
+
+
+def wikipedia():
+    """Download, extract, and parse the Korean Wikipedia dump into one JSON file."""
+    _dump_wiki()
+    _parse_wiki_dump()
+
+
+def ai_hub_news_corpus(input_dir: str, output_file: str):
+    """Merge AI Hub Korean web-corpus (SJML) JSON files into one JSON file.
+
+    https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&dataSetSn=624
+
+    Args:
+        input_dir: directory (searched recursively) holding the input JSON files.
+        output_file: path of the merged output JSON file.
+    """
+    all_documents = []
+    input_path = Path(input_dir)
+
+    try:
+        # Process every JSON file under the input directory.
+        for json_file in input_path.glob("**/*.json"):
+            logger.info(f"์ฒ๋ฆฌ ์ค์ธ ํ์ผ: {json_file}")
+
+            try:
+                with open(json_file, "r", encoding="utf-8") as f:
+                    data = json.load(f)
+
+                # Validate the SJML structure and extract title/content pairs.
+                if "SJML" in data and "text" in data["SJML"]:
+                    for doc in data["SJML"]["text"]:
+                        processed_doc = {"title": doc["title"], "text": doc["content"]}
+                        all_documents.append(processed_doc)
+                else:
+                    logger.warning(f"์๋ชป๋ JSON ๊ตฌ์กฐ: {json_file}")
+
+            except json.JSONDecodeError:
+                logger.error(f"JSON ํ์ฑ ์ค๋ฅ: {json_file}")
+            except Exception as e:
+                logger.error(f"ํ์ผ ์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {json_file}, ์ค๋ฅ: {str(e)}")
+
+        # Write the merged result as a single JSON file.
+        if all_documents:
+            with open(output_file, "w", encoding="utf-8") as f:
+                json.dump(all_documents, f, ensure_ascii=False, indent=2)
+            logger.info(f"์ฒ๋ฆฌ ์๋ฃ: ์ด {len(all_documents)}๊ฐ ๋ฌธ์๊ฐ {output_file}์ ์ ์ฅ๋จ")
+        else:
+            logger.warning("์ฒ๋ฆฌ๋ ๋ฌธ์๊ฐ ์์ต๋๋ค.")
+
+    except Exception as e:
+        logger.error(f"์ ์ฒด ์ฒ๋ฆฌ ๊ณผ์ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
+
+
+if __name__ == "__main__":
+ os.chdir("../../")
+
+ PROCESS_JSON_FILE = False
+ WIKIPEDIA = False
+ AI_HUB_NEWS_CORPUS = False
+
+ if PROCESS_JSON_FILE:
+ process_json_file("../data/documents.json")
+ if WIKIPEDIA:
+ wikipedia()
+ if AI_HUB_NEWS_CORPUS:
+ ai_hub_news_corpus("../data/ai_hub_news_corpus", "../data/ai_hub_news_corpus.json")
diff --git a/code/rag/data_process/wiki_dump.py b/code/rag/data_process/wiki_dump.py
new file mode 100644
index 0000000..a5ad05e
--- /dev/null
+++ b/code/rag/data_process/wiki_dump.py
@@ -0,0 +1,93 @@
+import json
+import os
+import re
+import urllib.request
+
+from loguru import logger
+
+
+def dump_wiki(data_path: str = "../data"):
+    """Download the Korean Wikipedia dump, extract it, and concatenate the text.
+
+    NOTE(review): duplicates _dump_wiki in data_process/external_data.py --
+    consider consolidating into a single shared helper.
+    """
+    dump_filename = "kowiki-latest-pages-articles.xml.bz2"
+    dump_path = os.path.join(data_path, dump_filename)
+    wiki_url = f"https://dumps.wikimedia.org/kowiki/latest/{dump_filename}"
+
+    # Equivalent of: wget https://dumps.wikimedia.org/kowiki/latest/<dump>
+    if not os.path.exists(dump_path):
+        logger.debug(f"์ํคํผ๋์ ๋คํ๋ฅผ ๋ค์ด๋ก๋ํฉ๋๋ค: {wiki_url}")
+        urllib.request.urlretrieve(wiki_url, dump_path)
+        logger.debug(f"๋ค์ด๋ก๋ ์๋ฃ: {dump_path}")
+
+    # Equivalent of: python -m wikiextractor.WikiExtractor <dump>
+    extract_dir = os.path.join(data_path, "text")
+    if not os.path.exists(extract_dir):
+        logger.debug("WikiExtractor๋ก ๋คํ ํ์ผ์ ์ถ์ถํฉ๋๋ค...")
+        os.system(f"python -m wikiextractor.WikiExtractor {dump_path} -o {extract_dir}")
+        logger.debug("์ถ์ถ ์๋ฃ")
+
+    def _get_filename_list(dirname):
+        # Collect files named wiki_NN produced by WikiExtractor.
+        filepaths = []
+        for root, dirs, files in os.walk(dirname):
+            for file in files:
+                filepath = os.path.join(root, file)
+                if re.match(r"wiki_[0-9][0-9]", file):
+                    filepaths.append(filepath)
+        return sorted(filepaths)
+
+    filepaths = _get_filename_list(extract_dir)
+    output_path = os.path.join(data_path, "wiki_dump.txt")
+
+    # Concatenate the contents of every extracted file.
+    all_text = ""
+    for filepath in filepaths:
+        with open(filepath, "r", encoding="utf-8") as f:
+            all_text += f.read() + "\n"
+
+    # Write the combined text to a single file.
+    with open(output_path, "w", encoding="utf-8") as f:
+        f.write(all_text)
+
+    logger.debug(f"์ด {len(filepaths)}๊ฐ์ ํ์ผ์ ์ฒ๋ฆฌํ์ต๋๋ค.")
+    logger.debug(f"๋ชจ๋ ๋ด์ฉ์ด {output_path} ํ์ผ์ ์ ์ฅ๋์์ต๋๋ค.")
+
+
+def parse_wiki_dump(file_path: str = "../data/wiki_dump.txt"):
+ """
+ ์ํคํผ๋์ ๋คํ ํ์ผ์ JSON ํ์์ผ๋ก ๋ณํํ๋ ํจ์
+ """
+ documents = []
+ current_doc = ""
+ doc_id = None
+ title = None
+
+ with open(file_path, "r", encoding="utf-8") as f:
+ for line in f:
+ # ์๋ก์ด ๋ฌธ์ ์์
+ if line.startswith(""):
+ if current_doc.strip(): # ๋น ๋ฌธ์๊ฐ ์๋ ๊ฒฝ์ฐ๋ง ์ถ๊ฐ
+ documents.append({"id": doc_id, "title": title, "text": current_doc.strip()})
+ # ๋ฌธ์ ๋ด์ฉ
+ else:
+ current_doc += line
+
+ # JSON ํ์ผ๋ก ์ ์ฅ
+ output_path = file_path.replace(".txt", ".json")
+ with open(output_path, "w", encoding="utf-8") as json_file:
+ json.dump(documents, json_file, ensure_ascii=False, indent=4)
+
+ logger.debug(f"JSON ํ์ผ์ด ์์ฑ๋์์ต๋๋ค: {output_path}")
+ logger.debug(f"์ด {len(documents)}๊ฐ์ ๋ฌธ์๊ฐ ์ฒ๋ฆฌ๋์์ต๋๋ค.")
+ return documents
+
+
+if __name__ == "__main__":
+ os.chdir("../../")
+ dump_wiki()
+ parse_wiki_dump()
diff --git a/code/rag/dpr_data.py b/code/rag/dpr_data.py
new file mode 100644
index 0000000..b13155f
--- /dev/null
+++ b/code/rag/dpr_data.py
@@ -0,0 +1,226 @@
+# from utils import get_passage_file
+from glob import glob
+import json
+import logging
+import math
+import os
+import pickle
+import re
+import typing
+from typing import Iterator, List, Sized, Tuple
+
+import torch
+from torch import tensor as T
+from torch.nn.utils.rnn import pad_sequence
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+
+def get_wiki_filepath(data_dir):
+    """Return all WikiExtractor output files matching <data_dir>/*/wiki_*."""
+    return glob(f"{data_dir}/*/wiki_*")
+
+
+def wiki_worker_init(worker_id):
+    """DataLoader worker_init_fn: give each worker a disjoint [start, end) slice.
+
+    Splits the dataset's overall range evenly across num_workers and narrows
+    this worker's dataset.start / dataset.end accordingly.
+    """
+    worker_info = torch.utils.data.get_worker_info()
+    dataset = worker_info.dataset
+    overall_start = dataset.start
+    overall_end = dataset.end
+    # Ceil so the last worker covers any remainder.
+    split_size = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
+    worker_id = worker_info.id
+    dataset.start = overall_start + worker_id * split_size
+    dataset.end = min(dataset.start + split_size, overall_end)  # prevent index overrun
+
+
+def get_passage_file(p_id_list: typing.List[int]) -> str:
+    """Return the pickled passage file whose id range covers all given ids.
+
+    Files are named "<start>-<end>.p"; the last matching file wins. Returns
+    None when no single file spans [min(p_id_list), max(p_id_list)].
+    """
+    target_file = None
+    p_id_max = max(p_id_list)
+    p_id_min = min(p_id_list)
+    for f in glob("processed_passages/*.p"):
+        # NOTE(review): the '/' split assumes POSIX paths; breaks on Windows.
+        s, e = f.split("/")[1].split(".")[0].split("-")
+        s, e = int(s), int(e)
+        if p_id_min >= s and p_id_max <= e:
+            target_file = f
+    return target_file
+
+
+# set logger
+# Ensure the log directory exists, then route all module logging to
+# logs/log.log at DEBUG level with caller metadata in every record.
+os.makedirs("logs", exist_ok=True)
+logging.basicConfig(
+    filename="logs/log.log",
+    level=logging.DEBUG,
+    format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s",
+)
+# Root logger shared by everything in this module.
+logger = logging.getLogger()
+
+
+def korquad_collator(batch: List[Tuple], padding_value: int) -> Tuple[torch.Tensor]:
+    """Collate (query_ids, p_id, passage_ids) samples into padded batches.
+
+    Returns (queries, query_attn_mask, passage_ids[:, None], passages,
+    passage_attn_mask); attention masks are 1 where the token is not padding.
+    """
+    batch_q = pad_sequence([T(e[0]) for e in batch], batch_first=True, padding_value=padding_value)
+    batch_q_attn_mask = (batch_q != padding_value).long()
+    batch_p_id = T([e[1] for e in batch])[:, None]  # column vector of gold passage ids
+    batch_p = pad_sequence([T(e[2]) for e in batch], batch_first=True, padding_value=padding_value)
+    batch_p_attn_mask = (batch_p != padding_value).long()
+    return (batch_q, batch_q_attn_mask, batch_p_id, batch_p, batch_p_attn_mask)
+
+
+class KorQuadSampler(torch.utils.data.BatchSampler):
+    """Batch sampler that keeps gold answer passages unique within a batch.
+
+    For in-batch-negative training no two samples in one batch may share the
+    same answer passage; duplicates are skipped, so one full pass can yield
+    slightly fewer samples than the dataset size.
+    """
+
+    def __init__(
+        self,
+        data_source: Sized,
+        batch_size: int,
+        drop_last: bool = False,
+        shuffle: bool = True,
+        generator=None,
+    ) -> None:
+        # Random order for training, sequential otherwise.
+        if shuffle:
+            sampler = torch.utils.data.RandomSampler(data_source, replacement=False, generator=generator)
+        else:
+            sampler = torch.utils.data.SequentialSampler(data_source)
+        super(KorQuadSampler, self).__init__(sampler=sampler, batch_size=batch_size, drop_last=drop_last)
+
+    def __iter__(self) -> Iterator[List[int]]:
+        sampled_p_id = []  # gold passage ids already used in the current batch
+        sampled_idx = []  # sample indices of the current batch
+        for idx in self.sampler:
+            item = self.sampler.data_source[idx]
+            if item[1] in sampled_p_id:
+                continue  # skip: this answer passage is already in the batch
+            sampled_idx.append(idx)
+            sampled_p_id.append(item[1])
+            if len(sampled_idx) >= self.batch_size:
+                yield sampled_idx
+                sampled_p_id = []
+                sampled_idx = []
+        # Emit the final partial batch unless drop_last is set.
+        if len(sampled_idx) > 0 and not self.drop_last:
+            yield sampled_idx
+
+
+class KorQuadDataset:
+    """KorQuAD QA pairs matched to pre-built ko-wiki passages for DPR training.
+
+    Produces (tokenized_question, passage_id, tokenized_passage) tuples and
+    caches the tokenized result as "<korquad_path minus .json>_processed.p".
+    """
+
+    def __init__(self, korquad_path: str, title_passage_map_path="title_passage_map.p"):
+        # NOTE(review): title_passage_map_path is accepted but _load_data()
+        # hard-codes "title_passage_map.p" -- confirm intended behavior.
+        self.korquad_path = korquad_path
+        self.data_tuples = []
+        self.tokenizer = AutoTokenizer.from_pretrained("monologg/kobert", trust_remote_code=True)
+        self.pad_token_id = self.tokenizer.get_vocab()["[PAD]"]
+        self.load()
+
+    @property
+    def dataset(self) -> List[Tuple]:
+        # Tokenized (question, passage_id, passage) triples.
+        return self.tokenized_tuples
+
+    def stat(self):
+        """Report statistics of the KorQuAD dataset (not implemented)."""
+        raise NotImplementedError()
+
+    def load(self):
+        """Load the cached preprocessed dataset, or build and cache it if absent."""
+        self.korquad_processed_path = f"{self.korquad_path.split('.json')[0]}_processed.p"
+        if os.path.exists(self.korquad_processed_path):
+            logger.debug("preprocessed file already exists, loading...")
+            with open(self.korquad_processed_path, "rb") as f:
+                self.tokenized_tuples = pickle.load(f)
+            logger.debug("successfully loaded tokenized_tuples into self.tokenized_tuples")
+
+        else:
+            self._load_data()
+            self._match_passage()
+            logger.debug("successfully loaded data_tuples into self.data_tuples")
+            # tokenizing raw dataset
+            self.tokenized_tuples = [
+                (self.tokenizer.encode(q), id, self.tokenizer.encode(p))
+                for q, id, p in tqdm(self.data_tuples, desc="tokenize")
+            ]
+            self._save_processed_dataset()
+            logger.debug("finished tokenization")
+
+    def _load_data(self):
+        """Read the raw KorQuAD JSON and the title -> passage-id map."""
+        with open(self.korquad_path, "rt", encoding="utf8") as f:
+            data = json.load(f)
+        self.raw_json = data["data"]
+        logger.debug("data loaded into self.raw_json")
+        with open("title_passage_map.p", "rb") as f:
+            self.title_passage_map = pickle.load(f)
+        logger.debug("title passage mapping loaded into self.title_passage_map")
+
+    def _get_cand_ids(self, title):
+        """Look up passage ids for a title in the pre-built ko-wiki map.
+
+        Falls back to the title with any parenthesized suffix removed;
+        returns (ids_or_None, refined_title_or_None).
+        """
+        refined_title = None
+        ret = self.title_passage_map.get(title, None)
+        if not ret:
+            refined_title = re.sub(r"\(.*\)", "", title).strip()
+            ret = self.title_passage_map.get(refined_title, None)
+        return ret, refined_title
+
+    def _match_passage(self):
+        """Match KorQuAD answers against ko-wiki passages.
+
+        Builds (query, passage_id, passage) tuples by selecting passages that
+        contain the answer plus +-5 characters of surrounding context, or,
+        failing that, just the answer text itself.
+        """
+        for item in tqdm(self.raw_json, desc="matching silver passages"):
+            title = item["title"].replace("_", " ")  # underscores -> spaces
+            para = item["paragraphs"]
+            cand_ids, refined_title = self._get_cand_ids(title)
+            if refined_title is not None and cand_ids:
+                logger.debug(f"refined the title and proceed : {title} -> {refined_title}")
+            if cand_ids is None:
+                logger.debug(f"No such title as {title} or {refined_title}. passing this title")
+                continue
+            target_file_p = get_passage_file(cand_ids)
+            if target_file_p is None:
+                logger.debug(f"No single target file for {title}, got passage ids {cand_ids}. passing this title")
+                continue
+            with open(target_file_p, "rb") as f:
+                target_file = pickle.load(f)
+            contexts = {cand_id: target_file[cand_id] for cand_id in cand_ids}
+
+            for p in para:
+                qas = p["qas"]
+                for qa in qas:
+                    answer = qa["answers"][0]["text"]  # take any one answer
+                    answer_pos = qa["answers"][0]["answer_start"]
+                    answer_clue_start = max(0, answer_pos - 5)
+                    answer_clue_end = min(len(p["context"]), answer_pos + len(answer) + 5)
+                    # +-5 chars of context around the answer pin the gold passage.
+                    answer_clue = p["context"][
+                        answer_clue_start:answer_clue_end
+                    ]
+                    question = qa["question"]
+                    # Passages that simply contain the clue text.
+                    answer_p = [
+                        (p_id, c) for p_id, c in contexts.items() if answer_clue in c
+                    ]
+                    if not answer_p:
+                        answer_p = [(p_id, c) for p_id, c in contexts.items() if answer in c]
+
+                    self.data_tuples.extend([(question, p_id, c) for p_id, c in answer_p])
+
+    def _save_processed_dataset(self):
+        """Pickle the tokenized tuples to the processed-cache path."""
+        with open(self.korquad_processed_path, "wb") as f:
+            pickle.dump(self.tokenized_tuples, f)
+        logger.debug(f"successfully saved self.tokenized_tuples into {self.korquad_processed_path}")
+
+
+# if __name__ == "__main__":
+# ds = KorQuadDataset(korquad_path="./data/KorQuAD_v1.0_train.json")
+
+# loader = torch.utils.data.DataLoader(
+# dataset=ds.dataset,
+# batch_sampler=KorQuadSampler(ds.dataset, batch_size=16, drop_last=False),
+# collate_fn=lambda x: korquad_collator(x, padding_value=ds.pad_token_id),
+# num_workers=4,
+# )
+# # logger.debug(len(_dataset.tokenized_tuples))
+# torch.manual_seed(123412341235)
+# cnt = 0
+# for batch in tqdm(loader):
+# #logger.debug(len(batch))
+# cnt += batch[0].size(0)
+# # break
+# logger.debug(cnt)
diff --git a/code/rag/encoder.py b/code/rag/encoder.py
new file mode 100644
index 0000000..8bb1593
--- /dev/null
+++ b/code/rag/encoder.py
@@ -0,0 +1,62 @@
+from copy import deepcopy
+import logging
+import os
+
+import torch
+from transformers import BertModel
+
+
+# ๋ ๊ฐ์ BertModel์ ์ฌ์ฉํ์ฌ passage์ query๋ฅผ encoding์ ์คํ
+# ํ ํฌ๋์ด์ง ํ์ ํ ํฐ์ ๊ณ ์ ๋ ํฌ๊ธฐ์ ๋ฒกํฐ๋ก ๋ณ๊ฒฝ
+
+# ๋ก๊ทธ ๋๋ ํ ๋ฆฌ ์์ฑ (์์ผ๋ฉด ์๋ก ์์ฑ)
+# Create the log directory if it does not exist.
+os.makedirs("logs", exist_ok=True)
+
+# Log to logs/log.log at DEBUG level with caller metadata in every record.
+logging.basicConfig(
+    filename="logs/log.log",
+    level=logging.DEBUG,
+    format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s",
+)
+logger = logging.getLogger()
+
+
+# KobertBiEncoder ํด๋์ค ์ ์
+class KobertBiEncoder(torch.nn.Module):
+    """DPR-style bi-encoder: separate KoBERT models for passages and queries."""
+
+    def __init__(self):
+        super(KobertBiEncoder, self).__init__()
+        # Independent BERT encoders for passages and for queries.
+        self.passage_encoder = BertModel.from_pretrained("monologg/kobert", trust_remote_code=True)
+        self.query_encoder = BertModel.from_pretrained("monologg/kobert", trust_remote_code=True)
+        # Embedding size = the pooler's output width.
+        self.emb_sz = self.passage_encoder.pooler.dense.out_features  # get cls token dim
+
+    def forward(self, x: torch.LongTensor, attn_mask: torch.LongTensor, type: str = "passage") -> torch.FloatTensor:
+        """Encode a batch with the passage or query encoder; returns pooler_output.
+
+        The parameter name ``type`` shadows the builtin but is part of the
+        public interface, so it is kept.
+        """
+        assert type in (
+            "passage",
+            "query",
+        ), "type should be either 'passage' or 'query'"
+        # Route to the encoder matching the requested input type.
+        if type == "passage":
+            return self.passage_encoder(input_ids=x, attention_mask=attn_mask).pooler_output
+        else:
+            return self.query_encoder(input_ids=x, attention_mask=attn_mask).pooler_output
+
+    def checkpoint(self, model_ckpt_path):
+        """Save a deep copy of the state dict to model_ckpt_path."""
+        torch.save(deepcopy(self.state_dict()), model_ckpt_path)
+        logger.debug(f"model self.state_dict saved to {model_ckpt_path}")
+
+    def load(self, model_ckpt_path):
+        """Load weights from model_ckpt_path into this module."""
+        with open(model_ckpt_path, "rb") as f:
+            state_dict = torch.load(f)
+        self.load_state_dict(state_dict)
+        logger.debug(f"model self.state_dict loaded from {model_ckpt_path}")
diff --git a/code/rag/index_runner.py b/code/rag/index_runner.py
new file mode 100644
index 0000000..6a28f73
--- /dev/null
+++ b/code/rag/index_runner.py
@@ -0,0 +1,188 @@
+from glob import glob
+import logging
+import math
+import os
+import typing
+from typing import List, Tuple
+
+from chunk_data import DataChunk
+from encoder import KobertBiEncoder
+import indexers
+import torch
+from torch import tensor as T
+from torch.nn.utils.rnn import pad_sequence
+from tqdm import tqdm
+import transformers
+
+
+# from utils import get_wiki_filepath, wiki_worker_init
+transformers.logging.set_verbosity_error() # ํ ํฌ๋์ด์ ์ด๊ธฐํ ๊ด๋ จ warning suppress
+
+
+def get_wiki_filepath(data_dir):
+    return glob(f"{data_dir}/*/wiki_*")  # wikiextractor layout: <data_dir>/AA/wiki_00, ...
+
+
+def wiki_worker_init(worker_id):
+    worker_info = torch.utils.data.get_worker_info()  # info about the current DataLoader worker
+    dataset = worker_info.dataset
+    # print(dataset)
+    # dataset =
+    overall_start = dataset.start
+    overall_end = dataset.end
+    split_size = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
+    worker_id = worker_info.id
+    # end_idx = min((worker_id+1) * split_size, len(dataset.data))
+    dataset.start = overall_start + worker_id * split_size
+    dataset.end = min(dataset.start + split_size, overall_end)  # guard against index overrun
+
+
+def get_passage_file(p_id_list: typing.List[int]) -> str:
+ """passage id๋ฅผ ๋ฐ์์ ํด๋น๋๋ ํ์ผ ์ด๋ฆ์ ๋ฐํํฉ๋๋ค."""
+ target_file = None
+ p_id_max = max(p_id_list)
+ p_id_min = min(p_id_list)
+ for f in glob("processed_passages/*.p"):
+ s, e = f.split("/")[1].split(".")[0].split("-")
+ s, e = int(s), int(e)
+ if p_id_min >= s and p_id_max <= e:
+ target_file = f
+ return target_file
+
+
+# logger basic config: append DEBUG-level logs to logs/log.log
+os.makedirs("logs", exist_ok=True)
+logging.basicConfig(
+    filename="logs/log.log",
+    level=logging.DEBUG,
+    format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s",
+)
+logger = logging.getLogger()
+
+
+def wiki_collator(batch: List, padding_value: int) -> Tuple[torch.Tensor]:
+    """Pad variable-length token-id passages into a batch tensor and build its attention mask."""
+    batch_p = pad_sequence([T(e) for e in batch], batch_first=True, padding_value=padding_value)
+    batch_p_attn_mask = (batch_p != padding_value).long()  # 1 for real tokens, 0 for padding
+    return (batch_p, batch_p_attn_mask)
+
+
+class WikiArticleStream(torch.utils.data.IterableDataset):
+    """
+    Stream dataset: indexing needs no random access, and streaming keeps memory bounded on a large corpus.
+    """
+
+    def __init__(self, wiki_path, chunker):
+        # self.chunk_size = chunk_size
+        super(WikiArticleStream, self).__init__()
+        self.chunker = chunker
+        self.pad_token_id = self.chunker.tokenizer.get_vocab()["[PAD]"]
+        self.wiki_path = wiki_path
+        self.max_length = 168  # maximum length for kowiki passage
+        # self.start = 0
+        # self.end = len(self.passages)
+
+    def __iter__(self):
+        # NOTE: padding to max_length is currently disabled (see commented-out code below).
+
+        _, passages = self.chunker.chunk(self.wiki_path)
+        logger.debug(f"chunked file {self.wiki_path}")
+        for passage in passages:
+            # if len(passage) > self.max_length:
+            #     continue  # skip passages longer than max_length
+            # padded_passage = T(
+            #     passage
+            #     + [self.pad_token_id for _ in range(self.max_length - len(passage))]
+            # )
+            yield passage
+
+
+class IndexRunner:
+    """Main class that runs indexing over the corpus.
+    It consists of a passage encoder, a data loader and a FAISS indexer."""
+
+    def __init__(
+        self,
+        data_dir: str,
+        model_ckpt_path: str,
+        indexer_type: str = "DenseFlatIndexer",
+        chunk_size: int = 100,
+        batch_size: int = 64,
+        buffer_size: int = 50000,
+        index_output: str = "",
+        device: str = "cuda:0",
+    ):
+        """
+        data_dir : directory with the Korean wiki dump to index; it contains subdirectories such as AA, AB.
+        indexer_type : which FAISS indexer class to use;
+            as in the DPR repo, one of Flat, HNSWFlat and HNSWSQ.
+        chunk_size : passage length used as the unit of indexing and searching.
+            The DPR paper defines a passage as 100 tokens plus the title.
+        """
+        if "=" in data_dir:  # "dir=N" syntax limits indexing to the first N wiki files
+            self.data_dir, self.to_this_page = data_dir.split("=")
+            self.to_this_page = int(self.to_this_page)
+            self.wiki_files = get_wiki_filepath(self.data_dir)
+        else:
+            self.data_dir = data_dir
+            self.wiki_files = get_wiki_filepath(self.data_dir)
+            self.to_this_page = len(self.wiki_files)
+
+        self.device = torch.device(device)
+        self.encoder = KobertBiEncoder().to(self.device)
+        self.encoder.load(model_ckpt_path)  # loading model
+        self.indexer = getattr(indexers, indexer_type)()  # look up the indexer class by name
+        self.chunk_size = chunk_size
+        self.batch_size = batch_size
+        self.buffer_size = buffer_size
+        self.loader = self.get_loader(
+            self.wiki_files[: self.to_this_page],
+            chunk_size,
+            batch_size,
+            worker_init_fn=None,
+        )
+        self.indexer.init_index(self.encoder.emb_sz)  # FAISS index sized to the encoder embedding dim
+        self.index_output = index_output if index_output else indexer_type
+
+    @staticmethod
+    def get_loader(wiki_files, chunk_size, batch_size, worker_init_fn=None):
+        chunker = DataChunk(chunk_size=chunk_size)
+        ds = torch.utils.data.ChainDataset(tuple(WikiArticleStream(path, chunker) for path in wiki_files))
+        loader = torch.utils.data.DataLoader(
+            ds,
+            batch_size=batch_size,
+            collate_fn=lambda x: wiki_collator(x, padding_value=chunker.tokenizer.get_vocab()["[PAD]"]),
+            num_workers=1,
+            worker_init_fn=worker_init_fn,
+        )  # TODO : check the behavior when a ChainDataset is used with more than 1 worker
+        return loader
+
+    def run(self):
+        _to_index = []  # buffer of (passage_id, embedding) pairs awaiting indexing
+        cur = 0  # running passage id
+        for batch in tqdm(self.loader, desc="indexing"):
+            p, p_mask = batch
+            p, p_mask = p.to(self.device), p_mask.to(self.device)
+            with torch.no_grad():
+                p_emb = self.encoder(p, p_mask, "passage")
+            _to_index += [(cur + i, _emb) for i, _emb in enumerate(p_emb.cpu().numpy())]
+            cur += p_emb.size(0)
+            if len(_to_index) > self.buffer_size - self.batch_size:  # flush before the buffer overflows
+                logger.debug(f"perform indexing... {len(_to_index)} added")
+                self.indexer.index_data(_to_index)
+                _to_index = []
+        if _to_index:  # flush the remainder
+            logger.debug(f"perform indexing... {len(_to_index)} added")
+            self.indexer.index_data(_to_index)
+            _to_index = []
+        os.makedirs(self.index_output, exist_ok=True)
+        self.indexer.serialize(self.index_output)
+
+
+# if __name__ == "__main__":
+# IndexRunner(
+# data_dir="./dataset/text",
+# model_ckpt_path="./my_model.pt",
+# index_output="2050iter_flat",
+# ).run()
+# # test_loader()
diff --git a/code/rag/indexers.py b/code/rag/indexers.py
new file mode 100644
index 0000000..e503a28
--- /dev/null
+++ b/code/rag/indexers.py
@@ -0,0 +1,233 @@
+# Credit : facebookresearch/DPR
+
+"""
+FAISS-based index components for dense retriever
+"""
+
+import logging
+import os
+import pickle
+from typing import List, Tuple
+
+import faiss
+import numpy as np
+
+
+logger = logging.getLogger()
+
+
+class DenseIndexer(object):
+    def __init__(self, buffer_size: int = 50000):
+        """
+        Initialize the indexer.
+        - buffer_size: maximum number of vectors handled per batch (default 50000).
+        - index_id_to_db_id: list mapping FAISS index ids to real database ids.
+        - index: the FAISS index object.
+        """
+        self.buffer_size = buffer_size
+        self.index_id_to_db_id = []  # maps internal FAISS index ids to database ids
+        self.index = None  # FAISS index object
+
+    def init_index(self, vector_sz: int):
+        raise NotImplementedError
+
+    def index_data(self, data: List[Tuple[object, np.array]]):
+        raise NotImplementedError
+
+    def get_index_name(self):
+        raise NotImplementedError
+
+    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
+        raise NotImplementedError
+
+    def serialize(self, file: str):
+        logger.info("Serializing index to %s", file)
+
+        if os.path.isdir(file):
+            index_file = os.path.join(file, "index.dpr")
+            meta_file = os.path.join(file, "index_meta.dpr")
+        else:
+            index_file = file + ".index.dpr"
+            meta_file = file + ".index_meta.dpr"
+
+        faiss.write_index(self.index, index_file)  # persist the FAISS index itself
+        with open(meta_file, mode="wb") as f:
+            pickle.dump(self.index_id_to_db_id, f)  # persist the id mapping
+
+    def get_files(self, path: str):
+        if os.path.isdir(path):
+            index_file = os.path.join(path, "index.dpr")  # directory layout uses fixed file names
+            meta_file = os.path.join(path, "index_meta.dpr")
+        else:
+            index_file = path + ".{}.dpr".format(self.get_index_name())
+            meta_file = path + ".{}_meta.dpr".format(self.get_index_name())
+        return index_file, meta_file
+
+    def index_exists(self, path: str):
+        index_file, meta_file = self.get_files(path)
+        return os.path.isfile(index_file) and os.path.isfile(meta_file)
+
+    def deserialize(self, path: str):
+        logger.info("Loading index from %s", path)
+        index_file, meta_file = self.get_files(path)
+
+        self.index = faiss.read_index(index_file)
+        logger.info("Loaded index of type %s and size %d", type(self.index), self.index.ntotal)
+
+        with open(meta_file, "rb") as reader:
+            self.index_id_to_db_id = pickle.load(reader)
+        assert (
+            len(self.index_id_to_db_id) == self.index.ntotal
+        ), "Deserialized index_id_to_db_id should match faiss index size"
+
+    def _update_id_mapping(self, db_ids: List) -> int:
+        self.index_id_to_db_id.extend(db_ids)
+        return len(self.index_id_to_db_id)
+
+
+class DenseFlatIndexer(DenseIndexer):
+    def __init__(self, buffer_size: int = 50000):
+        super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
+
+    def init_index(self, vector_sz: int):
+        self.index = faiss.IndexFlatIP(vector_sz)  # exact inner-product (dot) index
+
+    def index_data(self, data: List[Tuple[object, np.array]]):
+        n = len(data)
+        # indexing in batches is beneficial for many faiss index types
+        for i in range(0, n, self.buffer_size):
+            db_ids = [t[0] for t in data[i : i + self.buffer_size]]
+            vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size]]
+            vectors = np.concatenate(vectors, axis=0)
+            total_data = self._update_id_mapping(db_ids)
+            self.index.add(vectors)
+            logger.info("data indexed %d", total_data)
+
+        indexed_cnt = len(self.index_id_to_db_id)
+        logger.info("Total data indexed %d", indexed_cnt)
+
+    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
+        scores, indexes = self.index.search(query_vectors, top_docs)  # nearest neighbours per query
+        # convert to external ids
+        db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
+        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
+        return result
+
+    def get_index_name(self):
+        return "flat_index"
+
+
+class DenseHNSWFlatIndexer(DenseIndexer):
+ """
+ Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage
+ """
+
+ def __init__(
+ self,
+ buffer_size: int = 1e9,
+ store_n: int = 512,
+ ef_search: int = 128,
+ ef_construction: int = 200,
+ ):
+ super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size)
+ self.store_n = store_n
+ self.ef_search = ef_search
+ self.ef_construction = ef_construction
+ self.phi = 0
+
+ def init_index(self, vector_sz: int):
+ # IndexHNSWFlat supports L2 similarity only
+ # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension
+ index = faiss.IndexHNSWFlat(vector_sz + 1, self.store_n) # L2 ๊ฑฐ๋ฆฌ ๊ธฐ๋ฐ HNSW ์ธ๋ฑ์ค ์ด๊ธฐํ.
+ index.hnsw.efSearch = self.ef_search
+ index.hnsw.efConstruction = self.ef_construction
+ self.index = index
+
+ def index_data(self, data: List[Tuple[object, np.array]]):
+ n = len(data)
+
+ # max norm is required before putting all vectors in the index to convert inner product similarity to L2
+ if self.phi > 0:
+ raise RuntimeError(
+ "DPR HNSWF index needs to index all data at once," "results will be unpredictable otherwise."
+ )
+ phi = 0
+ for i, item in enumerate(data):
+ id, doc_vector = item[0:2]
+ norms = (doc_vector**2).sum()
+ phi = max(phi, norms)
+ logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi))
+ self.phi = phi
+
+ # indexing in batches is beneficial for many faiss index types
+ bs = int(self.buffer_size)
+ for i in range(0, n, bs):
+ db_ids = [t[0] for t in data[i : i + bs]]
+ vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + bs]]
+
+ norms = [(doc_vector**2).sum() for doc_vector in vectors]
+ aux_dims = [np.sqrt(phi - norm) for norm in norms]
+ hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in enumerate(vectors)]
+ hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)
+ self.train(hnsw_vectors)
+
+ self._update_id_mapping(db_ids)
+ self.index.add(hnsw_vectors)
+ logger.info("data indexed %d", len(self.index_id_to_db_id))
+ indexed_cnt = len(self.index_id_to_db_id)
+ logger.info("Total data indexed %d", indexed_cnt)
+
+ def train(self, vectors: np.array):
+ pass
+
+ def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
+ aux_dim = np.zeros(len(query_vectors), dtype="float32")
+ query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
+ logger.info("query_hnsw_vectors %s", query_nhsw_vectors.shape)
+ scores, indexes = self.index.search(query_nhsw_vectors, top_docs)
+ # convert to external ids
+ db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
+ result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
+ return result
+
+ def deserialize(self, file: str):
+ super(DenseHNSWFlatIndexer, self).deserialize(file)
+ # to trigger exception on subsequent indexing
+ self.phi = 1
+
+ def get_index_name(self):
+ return "hnsw_index"
+
+
+class DenseHNSWSQIndexer(DenseHNSWFlatIndexer):
+    """
+    Efficient index for retrieval. Note: default settings are for high accuracy but also high RAM usage
+    """
+
+    def __init__(
+        self,
+        buffer_size: int = 1e10,
+        store_n: int = 128,
+        ef_search: int = 128,
+        ef_construction: int = 200,
+    ):
+        super(DenseHNSWSQIndexer, self).__init__(
+            buffer_size=buffer_size,
+            store_n=store_n,
+            ef_search=ef_search,
+            ef_construction=ef_construction,
+        )
+
+    def init_index(self, vector_sz: int):
+        # IndexHNSWFlat supports L2 similarity only
+        # so we have to apply DOT -> L2 similarity space conversion with the help of an extra dimension
+        index = faiss.IndexHNSWSQ(vector_sz + 1, faiss.ScalarQuantizer.QT_8bit, self.store_n)
+        index.hnsw.efSearch = self.ef_search
+        index.hnsw.efConstruction = self.ef_construction
+        self.index = index
+
+    def train(self, vectors: np.array):
+        self.index.train(vectors)  # the scalar quantizer must be trained before vectors are added
+
+    def get_index_name(self):
+        return "hnswsq_index"
diff --git a/code/rag/prepare_dense.py b/code/rag/prepare_dense.py
new file mode 100644
index 0000000..3f81dfe
--- /dev/null
+++ b/code/rag/prepare_dense.py
@@ -0,0 +1,135 @@
+import logging
+import os
+
+from chunk_data import save_orig_passage, save_title_index_map
+from dpr_data import KorQuadDataset
+from encoder import KobertBiEncoder
+
+# ์ธ๋ถ ์คํฌ๋ฆฝํธ์์ IndexRunner ์ํฌํธ
+from index_runner import IndexRunner
+from indexers import DenseFlatIndexer # index ๊ด๋ จ
+from retriever import KorDPRRetriever # retriever.py์์ ๊ฐ์ ธ์ค๊ธฐ
+import torch
+from trainer import Trainer
+import transformers
+
+
+transformers.logging.set_verbosity_error()  # suppress tokenizer-initialization warnings
+
+# Logging setup: write DEBUG-level logs to logs/log.log.
+os.makedirs("logs", exist_ok=True)
+logging.basicConfig(
+    filename="logs/log.log",
+    level=logging.DEBUG,
+    format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s",
+)
+logger = logging.getLogger()
+
+
+# ๋ชจ๋ธ ์กด์ฌ ์ฌ๋ถ ํ์ธ ํจ์
+def check_if_model_exists(model_path: str):
+ """๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๊ฐ ์กด์ฌํ๋์ง ํ์ธํ๋ ํจ์"""
+ return os.path.exists(model_path)
+
+
+# ์ํค ๋ฐ์ดํฐ ์ฒ๋ฆฌ ํจ์
+def process_wiki_data():
+ """
+ ์ํค ๋ฐ์ดํฐ๋ฅผ ์ฒ๋ฆฌํ๊ณ ํ์ํ ํผํด ํ์ผ์ ์์ฑํฉ๋๋ค.
+ """
+ processed_passages_path = "processed_passages"
+
+ if not os.path.exists(processed_passages_path):
+ logger.debug(f"'{processed_passages_path}' ํด๋๊ฐ ์กด์ฌํ์ง ์์ต๋๋ค. ๋ฐ์ดํฐ ์ฒ๋ฆฌ๋ฅผ ์์ํฉ๋๋ค.")
+
+ # 1. chunk ๋ฐ์ดํฐ๋ฅผ ํผํด ํ์ผ๋ก ๋ณํํ์ฌ processed_passages ํด๋์ ์ ์ฅ (์ฝ 10๋ถ ์์)
+ save_orig_passage()
+
+ # 2. ์ ๋ชฉ๊ณผ ์ธ๋ฑ์ค ๋งคํ ์ ์ฅ
+ save_title_index_map()
+
+ logger.debug("๋ฐ์ดํฐ ์ฒ๋ฆฌ๊ฐ ์๋ฃ๋์์ต๋๋ค.")
+ else:
+ logger.debug(f"'{processed_passages_path}' ํด๋๊ฐ ์ด๋ฏธ ์กด์ฌํฉ๋๋ค. ๋ฐ์ดํฐ ์ฒ๋ฆฌ๋ฅผ ๊ฑด๋๋๋๋ค.")
+
+
+if __name__ == "__main__":
+    # Process the wiki data.
+    # Pickled passages are written into the processed_passages folder (takes ~10 min).
+    process_wiki_data()
+
+    # Model checkpoint path.
+    model_path = "./output/my_model.pt"
+
+    # Train the model on the KorQuAD data.
+    # Training is skipped when a checkpoint already exists.
+    if check_if_model_exists(model_path):
+        logger.debug(f"이미 학습된 모델이 {model_path}에 존재합니다. 학습을 건너뜁니다.")
+    else:
+        logger.debug("학습된 모델이 없습니다. 학습을 시작합니다.")
+
+        # Prepare the model and the datasets.
+        device = torch.device("cuda:0")
+        model = KobertBiEncoder()
+        train_dataset = KorQuadDataset("./data/KorQuAD_v1.0_train.json")
+        valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json")
+
+        # Create the Trainer and start training.
+        my_trainer = Trainer(
+            model=model,
+            device=device,
+            train_dataset=train_dataset,
+            valid_dataset=valid_dataset,
+            num_epoch=10,
+            batch_size=128 - 32,
+            lr=1e-5,
+            betas=(0.9, 0.99),
+            num_warmup_steps=1000,
+            num_training_steps=100000,
+            valid_every=30,
+            best_val_ckpt_path=model_path,
+        )
+
+        # Optionally resume a previous training state.
+        # my_trainer.load_training_state()
+
+        # Start training.
+        my_trainer.fit()
+
+    # Run indexing (using IndexRunner).
+    index_path = "./2050iter_flat"  # output path for the serialized index
+
+    # Indexing is skipped when the index already exists.
+    if not os.path.exists(index_path):
+        logger.info("인덱스가 존재하지 않습니다. 인덱싱을 시작합니다.")
+        index_runner = IndexRunner(
+            data_dir="./text",
+            model_ckpt_path="./output/my_model.pt",
+            index_output=index_path,
+        )
+        index_runner.run()
+    else:
+        logger.info(f"인덱스 파일 '{index_path}'가 이미 존재합니다. 인덱싱을 건너뜁니다.")
+
+    # Load the serialized index files.
+    index = DenseFlatIndexer()
+    index.deserialize(path=index_path)  # load the already-built index
+
+    # Build a KorDPRRetriever (from retriever.py) and run a query.
+    model = KobertBiEncoder()
+    model.load("./output/my_model.pt")
+    model.eval()
+
+    valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json")
+    retriever = KorDPRRetriever(model=model, valid_dataset=valid_dataset, index=index)
+
+    # Set the query and k.
+    query = "중국의 천안문 사태가 일어난 년도는?"
+    k = 10  # print the top-10 most similar passages
+
+    # Retrieve the k most similar passages for the query.
+    passages = retriever.retrieve(query=query, k=k)
+
+    # Print each retrieved passage with its similarity score.
+    for idx, (passage, sim) in enumerate(passages):
+        logger.debug(f"Rank {idx + 1} | Similarity: {sim:.4f} | Passage: {passage}")
diff --git a/code/rag/reranker.py b/code/rag/reranker.py
new file mode 100644
index 0000000..2716a23
--- /dev/null
+++ b/code/rag/reranker.py
@@ -0,0 +1,105 @@
+import gc
+import os
+from typing import Dict, List
+
+from dotenv import load_dotenv
+from loguru import logger
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+from .retriever_elastic import ElasticsearchRetriever
+
+
+class Reranker:  # cross-encoder reranker: rescores (query, passage) pairs with a ko-reranker model
+    def __init__(
+        self,
+        model_path: str = "Dongjin-kr/ko-reranker",
+        batch_size: int = 128,
+        max_length: int = 512,
+    ):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when available
+        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
+        self.model.to(self.device)
+        self.model.eval()
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.batch_size = batch_size
+        self.max_length = max_length
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Free GPU memory when the context exits.
+        del self.model
+        del self.tokenizer
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def _exp_normalize(self, x):
+        y = np.exp(x - x.max(axis=1, keepdims=True))  # numerically stable softmax numerator
+        return y / y.sum(axis=1, keepdims=True)
+
+    def rerank(self, queries: List[str], retrieve_results: List[List[Dict]], topk: int = 5) -> List[List[Dict]]:
+        # Build one flat list of (query, passage-text) pairs.
+        all_pairs = []
+        for query, results in zip(queries, retrieve_results):
+            for result in results:
+                all_pairs.append([query, result["text"]])
+
+        # Score the pairs in batches.
+        all_scores = []
+        for i in tqdm(range(0, len(all_pairs), self.batch_size), desc="Reranking"):
+            batch_pairs = all_pairs[i : i + self.batch_size]
+
+            with torch.no_grad():
+                inputs = self.tokenizer(
+                    batch_pairs,
+                    padding=True,
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=self.max_length,
+                )
+                inputs = {k: v.to(self.device) for k, v in inputs.items()}
+                batch_scores = self.model(**inputs, return_dict=True).logits.view(-1).float().cpu().numpy()
+                all_scores.extend(batch_scores)
+
+        all_scores = np.array(all_scores)
+
+        reranked_results = []
+        start = 0
+        for results in retrieve_results:  # walk the flat score array, one slice per query
+            end = start + len(results)
+            scores = all_scores[start:end]
+            scores = self._exp_normalize(scores.reshape(1, -1)).flatten()
+            top_indices = np.argsort(scores)[-topk:][::-1]  # indices of the topk highest scores
+
+            reranked_batch = [{"text": results[i]["text"], "score": float(scores[i])} for i in top_indices]
+            reranked_results.append(reranked_batch)
+            start = end
+
+        return reranked_results
+
+
+if __name__ == "__main__":
+    config_folder = os.path.join(os.path.dirname(__file__), "..", "..", "config")  # holds .env with ES credentials
+    load_dotenv(os.path.join(config_folder, ".env"))
+
+    reranker = Reranker(
+        model_path="Dongjin-kr/ko-reranker",
+        batch_size=128,
+        max_length=512,
+    )
+    retriever = ElasticsearchRetriever(
+        index_name="two-wiki-index",
+    )
+
+ query = "์ ๋น๋ค ์๋ง ๋ช
์ด ๋๊ถ ์์ ๋ชจ์ฌ ๋ง ๋๋ฌ์ ์์์ ๋ค์ ์ค๋ฆฝํ ๊ฒ์ ์ฒญํ๋, (๊ฐ)์ด/๊ฐ ํฌ๊ฒ ๋
ธํ์ฌ ํ์ฑ๋ถ์ ์กฐ๋ก(็้ท)์ ๋ณ์กธ๋ก ํ์ฌ ๊ธ ํ ๊ฐ ๋ฐ์ผ๋ก ๋ชฐ์๋ด๊ฒ ํ๊ณ ๋๋์ด ์ฒ์ฌ ๊ณณ์ ์์์ ์ฒ ํํ๊ณ ๊ทธ ํ ์ง๋ฅผ ๋ชฐ์ํ์ฌ ๊ด์ ์ํ๊ฒ ํ์๋ค.๏ผ๋ํ๊ณ๋
์ฌ" # noqa: E501
+    retriever_result = retriever.retrieve(query, top_k=5)  # first-stage lexical retrieval
+    logger.debug("Elasticsearch Retriever")
+    logger.debug(f"{retriever_result[:5]}")
+
+    reranked_results = reranker.rerank(queries=[query], retrieve_results=[retriever_result], topk=3)  # second stage
+    logger.debug("Reranker")
+    logger.debug(f"{reranked_results}")
diff --git a/code/rag/retriever.py b/code/rag/retriever.py
new file mode 100644
index 0000000..b32d5b6
--- /dev/null
+++ b/code/rag/retriever.py
@@ -0,0 +1,187 @@
+from collections import defaultdict
+from glob import glob
+import math
+import os
+import pickle
+import typing
+
+# from utils import get_passage_file
+from typing import List
+
+from dpr_data import KorQuadDataset, KorQuadSampler, korquad_collator
+from encoder import KobertBiEncoder
+from indexers import DenseFlatIndexer
+from loguru import logger
+import torch
+from torch import tensor as T
+from tqdm import tqdm
+
+
+def get_wiki_filepath(data_dir):
+    return glob(f"{data_dir}/*/wiki_*")  # wikiextractor layout: <data_dir>/AA/wiki_00, ...
+
+
+def wiki_worker_init(worker_id):
+    worker_info = torch.utils.data.get_worker_info()  # info about the current DataLoader worker
+    dataset = worker_info.dataset
+    # logger.debug(dataset)
+    # dataset =
+    overall_start = dataset.start
+    overall_end = dataset.end
+    split_size = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
+    worker_id = worker_info.id
+    # end_idx = min((worker_id+1) * split_size, len(dataset.data))
+    dataset.start = overall_start + worker_id * split_size
+    dataset.end = min(dataset.start + split_size, overall_end)  # guard against index overrun
+
+
+def get_passage_file(p_id_list: typing.List[int]) -> str:
+    """Return the pickled passage file covering every passage id in p_id_list (or None)."""
+    target_file = None
+    p_id_max = max(p_id_list)
+    p_id_min = min(p_id_list)
+
+    # Resolve 'processed_passages' relative to this file's location.
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    passages_dir = os.path.join(current_dir, "processed_passages")
+
+    # Scan the pickled files; names look like "<start>-<end>.p".
+    for f in glob(f"{passages_dir}/*.p"):
+        file_name = os.path.basename(f)
+        s, e = file_name.split(".")[0].split("-")
+        s, e = int(s), int(e)
+
+        if p_id_min >= s and p_id_max <= e:
+            target_file = f
+            break
+
+    if target_file is None:
+        logger.debug(f"No file found for passage IDs: {p_id_list}")
+
+    return target_file
+
+
+class KorDPRRetriever:
+ def __init__(self, model, valid_dataset, index, val_batch_size: int = 64, device="cuda:0"):
+ # ๋ชจ๋ธ์ด ๊ฒฝ๋ก๋ก ์ฃผ์ด์ง ๊ฒฝ์ฐ ๋ก๋
+ if isinstance(model, str):
+ self.model = KobertBiEncoder()
+ self.model.load(model)
+ else:
+ self.model = model
+
+ # ๋ชจ๋ธ์ ํด๋น ๋๋ฐ์ด์ค๋ก ์ด๋
+ self.model = self.model.to(device)
+ self.model.eval()
+
+ # ๋ฐ์ดํฐ์
๋ก๋
+ self.valid_dataset = valid_dataset
+
+ # ์ธ๋ฑ์ค๊ฐ ๊ฒฝ๋ก๋ก ์ฃผ์ด์ง ๊ฒฝ์ฐ ๋ก๋
+ if isinstance(index, str):
+ self.index = DenseFlatIndexer()
+ self.index.deserialize(path=index)
+ else:
+ self.index = index
+ self.model = model.to(device)
+ self.device = device
+ self.tokenizer = valid_dataset.tokenizer
+ self.val_batch_size = val_batch_size
+ self.valid_loader = torch.utils.data.DataLoader(
+ dataset=valid_dataset.dataset,
+ batch_sampler=KorQuadSampler(valid_dataset.dataset, batch_size=val_batch_size, drop_last=False),
+ collate_fn=lambda x: korquad_collator(x, padding_value=valid_dataset.pad_token_id),
+ num_workers=4,
+ )
+ self.index = index
+
+ def val_top_k_acc(self, k: List[int] = [5] + list(range(10, 101, 10))):
+ """validation set์์ top k ์ ํ๋๋ฅผ ๊ณ์ฐํฉ๋๋ค."""
+
+ self.model.eval() # ํ๊ฐ ๋ชจ๋
+ k_max = max(k)
+ sample_cnt = 0
+ retr_cnt = defaultdict(int)
+ with torch.no_grad():
+ for batch in tqdm(self.valid_loader, desc="valid"):
+ # batch_q, batch_q_attn_mask, batch_p_id, batch_p, batch_p_attn_mask
+ q, q_mask, p_id, a, a_mask = batch
+ q, q_mask = (
+ q.to(self.device),
+ q_mask.to(self.device),
+ )
+ q_emb = self.model(q, q_mask, "query") # bsz x bert_dim
+ result = self.index.search_knn(query_vectors=q_emb.cpu().numpy(), top_docs=k_max)
+
+ for (pred_idx_lst, _), true_idx, _a, _a_mask in zip(result, p_id, a, a_mask):
+ a_len = _a_mask.sum()
+ _a = _a[:a_len]
+ _a = _a[1:-1]
+ _a_txt = self.tokenizer.decode(_a).strip()
+ docs = [pickle.load(open(get_passage_file([idx]), "rb"))[idx] for idx in pred_idx_lst]
+
+ for _k in k:
+ if _a_txt in " ".join(docs[:_k]):
+ retr_cnt[_k] += 1
+
+ bsz = q.size(0)
+ sample_cnt += bsz
+ retr_acc = {_k: float(v) / float(sample_cnt) for _k, v in retr_cnt.items()}
+ return retr_acc
+
+ def retrieve(self, query: str, k: int = 100):
+ """์ฃผ์ด์ง ์ฟผ๋ฆฌ์ ๋ํด ๊ฐ์ฅ ์ ์ฌ๋๊ฐ ๋์ passage๋ฅผ ๋ฐํํฉ๋๋ค."""
+ self.model.eval() # ํ๊ฐ ๋ชจ๋
+ tok = self.tokenizer.batch_encode_plus([query], truncation=True, padding=True, max_length=512)
+
+ # Ensure tensors are moved to the same device as the model (cuda:0)
+ input_ids = T(tok["input_ids"]).to(self.device)
+ attention_mask = T(tok["attention_mask"]).to(self.device)
+
+ with torch.no_grad():
+ out = self.model(input_ids, attention_mask, "query")
+
+ result = self.index.search_knn(query_vectors=out.cpu().numpy(), top_docs=k)
+ # logger.debug(result)
+ # ์๋ฌธ ๊ฐ์ ธ์ค๊ธฐ
+ passages = []
+ for idx, sim in zip(*result[0]):
+ # logger.debug(idx)
+ path = get_passage_file([idx])
+ if not path:
+ logger.debug(f"์ฌ๋ฐ๋ฅธ ๊ฒฝ๋ก์ ํผํดํ๋ ์ํคํผ๋์๊ฐ ์๋์ง ํ์ธํ์ธ์.No single passage path for {idx}")
+ continue
+ with open(path, "rb") as f:
+ passage_dict = pickle.load(f)
+ # logger.debug(f"passage : {passage_dict[idx]}, sim : {sim}")
+ passages.append((passage_dict[idx], sim))
+ # logger.debug("์ฑ๊ณต!!!!!!")
+ return passages
+
+
+if __name__ == "__main__":
+    # parser = argparse.ArgumentParser()
+    # parser.add_argument("--query", "-q", type=str, required=True)
+    # parser.add_argument("--k", "-k", type=int, required=True)
+    # args = parser.parse_args()
+
+    model = KobertBiEncoder()
+    model.load("./output/my_model.pt")  # load trained bi-encoder weights
+    model.eval()
+    valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json")
+    index = DenseFlatIndexer()
+    index.deserialize(path="./2050iter_flat")  # load a previously built FAISS index
+
+    retriever = KorDPRRetriever(model=model, valid_dataset=valid_dataset, index=index)
+
+    # Set the query and k.
+ query = "(๊ฐ)์ด/๊ฐ ํฌ๊ฒ ๋
ธํ์ฌ ํ์ฑ๋ถ์ ์กฐ๋ก(็้ท)์ ๋ณ์กธ๋ก ํ์ฌ ๊ธ ํ ๊ฐ ๋ฐ์ผ๋ก ๋ชฐ์๋ด๊ฒ ํ๊ณ ๋๋์ด ์ฒ์ฌ ๊ณณ์ ์์์ ์ฒ ํํ๊ณ ๊ทธ ํ ์ง๋ฅผ ๋ชฐ์ํ์ฌ ๊ด์ ์ํ๊ฒ ํ์๋ค.๏ผ๋ํ๊ณ๋
์ฌ ๏ผ" # noqa: E501
+    # query = "성화지역의 저자는?"
+    k = 10  # set k to 20 to print the top-20 similar passages instead
+
+    # Retrieve the k most similar passages for the query.
+    passages = retriever.retrieve(query=query, k=k)
+
+    # Print each passage with its similarity score.
+    for idx, (passage, sim) in enumerate(passages):
+        logger.debug(f"Rank {idx + 1} | Similarity: {sim:.4f} | Passage: {passage}")
diff --git a/code/rag/retriever_bm25.py b/code/rag/retriever_bm25.py
new file mode 100644
index 0000000..532fcb9
--- /dev/null
+++ b/code/rag/retriever_bm25.py
@@ -0,0 +1,140 @@
+import json
+import os
+import pickle
+from typing import Dict, List, Optional
+
+from loguru import logger
+import numpy as np
+from rank_bm25 import BM25Okapi
+
+
+# from konlpy.tag import Okt
+# okt = Okt()
+# def okt_specific_pos_tokenizer(text, stem=True, norm=True):
+# # pos ํ๊น
์ํ
+# pos_tagged = okt.pos(text, stem=stem, norm=norm)
+# # ๋ช
์ฌ(Noun), ํ์ฉ์ฌ(Adjective), ๋์ฌ(Verb)๋ง ํํฐ๋ง
+# filtered_words = [word for word, pos in pos_tagged if pos in ["Noun", "Adjective", "Verb"]]
+# return filtered_words
+
+
+# Deprecated: too slow; no longer used.
+class BM25Retriever:
+    def __init__(
+        self,
+        tokenize_fn=None,
+        data_path: Optional[str] = "../data/",
+        pickle_filename: str = "wiki_bm25.pkl",
+        doc_filename: Optional[str] = "wiki_document.json",
+    ) -> None:
+        self.tokenize_fn = tokenize_fn if tokenize_fn else lambda x: x.split()  # default: whitespace split
+        self.pickle_path = os.path.join(data_path, pickle_filename)
+        self.bm25 = None
+        self.corpus = []
+
+        # Load the document dataset.
+        self._load_dataset(os.path.join(data_path, doc_filename))
+
+        # Reuse an existing serialized index when present.
+        if os.path.exists(self.pickle_path):
+            self._load_pickle()
+            return
+
+        # Otherwise build a fresh index.
+        self._initialize_retriever()
+
+ def _load_dataset(self, json_path):
+ logger.info("๋ฌธ์ ๋ฐ์ดํฐ์
๋ก๋")
+ with open(json_path, "r", encoding="utf-8") as f:
+ docs = json.load(f)
+ self.corpus = [f"{doc['title']}: {doc['text']}" for doc in docs]
+
+    def _load_pickle(self):
+        logger.info("기존 BM25 인덱스 로드")
+        with open(self.pickle_path, "rb") as f:
+            data = pickle.load(f)
+        self.bm25 = data["bm25"]  # the corpus itself is reloaded from JSON, not from the pickle
+
+    def _initialize_retriever(self):
+        logger.info("새로운 BM25 인덱스 생성")
+
+        tokenized_corpus = [self.tokenize_fn(doc) for doc in self.corpus]
+        self.bm25 = BM25Okapi(tokenized_corpus)
+
+        with open(self.pickle_path, "wb") as f:  # persist the fitted model for reuse
+            pickle.dump(
+                {
+                    "bm25": self.bm25,
+                },
+                f,
+            )
+        logger.info("인덱스 생성 완료")
+
+    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
+        """
+        Return the top_k documents for a single query.
+        """
+        if not self.bm25:
+            raise Exception("BM25 모델이 초기화되지 않았습니다.")
+
+        tokenized_query = self.tokenize_fn(query)
+        doc_scores = self.bm25.get_scores(tokenized_query)
+        top_indices = np.argsort(doc_scores)[-top_k:][::-1]  # highest-scoring documents first
+
+        results = []
+        for idx in top_indices:
+            results.append(
+                {
+                    "text": self.corpus[idx],
+                    "score": float(doc_scores[idx]),
+                }
+            )
+        return results
+
+    def bulk_retrieve(self, queries: List[str], top_k: int = 3) -> List[List[Dict]]:
+        """
+        Run retrieval for several queries in one call.
+        """
+        if not self.bm25:
+            raise Exception("BM25 모델이 초기화되지 않았습니다.")
+
+        results = []
+        logger.info(f"{len(queries)}개 쿼리 일괄 검색")
+        # Tokenize every query up front.
+        tokenized_queries = [self.tokenize_fn(query) for query in queries]
+
+        # Search per query.
+        for tokenized_query in tokenized_queries:
+            doc_scores = self.bm25.get_scores(tokenized_query)
+            top_indices = np.argsort(doc_scores)[-top_k:][::-1]
+
+            query_results = []
+            for idx in top_indices:
+                query_results.append(
+                    {
+                        "text": self.corpus[idx],
+                        "score": float(doc_scores[idx]),
+                    }
+                )
+            results.append(query_results)
+
+        logger.info(f"{len(queries)}개 쿼리 일괄 검색 완료")
+        return results
+
+
+if __name__ == "__main__":
+    os.chdir("..")  # run relative to the project root
+    retriever = BM25Retriever(
+        tokenize_fn=None,
+        data_path="../data/",
+        pickle_filename="wiki_bm25.pkl",
+        doc_filename="wiki.json",
+    )
+
+ query = "์ ๋น๋ค ์๋ง ๋ช
์ด ๋๊ถ ์์ ๋ชจ์ฌ ๋ง ๋๋ฌ์ ์์์ ๋ค์ ์ค๋ฆฝํ ๊ฒ์ ์ฒญํ๋, (๊ฐ)์ด/๊ฐ ํฌ๊ฒ ๋
ธํ์ฌ ํ์ฑ๋ถ์ ์กฐ๋ก(็้ท)์ ๋ณ์กธ๋ก ํ์ฌ ๊ธ ํ ๊ฐ ๋ฐ์ผ๋ก ๋ชฐ์๋ด๊ฒ ํ๊ณ ๋๋์ด ์ฒ์ฌ ๊ณณ์ ์์์ ์ฒ ํํ๊ณ ๊ทธ ํ ์ง๋ฅผ ๋ชฐ์ํ์ฌ ๊ด์ ์ํ๊ฒ ํ์๋ค.๏ผ๋ํ๊ณ๋
์ฌ" # noqa: E501
+    results = retriever.retrieve(query, top_k=5)
+
+    for i, result in enumerate(results, 1):  # 1-based rank display
+        logger.debug(f"\n검색 결과 {i}")
+        logger.debug(f"점수: {result['score']:.4f}")
+        logger.debug(f"내용: {result['text'][:200]}...")
diff --git a/code/rag/retriever_elastic.py b/code/rag/retriever_elastic.py
new file mode 100644
index 0000000..597a4d2
--- /dev/null
+++ b/code/rag/retriever_elastic.py
@@ -0,0 +1,190 @@
+import json
+import os
+import re
+from typing import Dict, List, Optional
+import warnings
+
+from dotenv import load_dotenv
+from elasticsearch import Elasticsearch, ElasticsearchWarning
+from loguru import logger
+
+
+# ElasticsearchWarning ๋ฌด์
+warnings.filterwarnings("ignore", category=ElasticsearchWarning)
+
+
class ElasticsearchRetriever:
    """Sparse retriever backed by an Elasticsearch index.

    If the target index does not exist, it is created from the given settings
    file and populated with documents loaded from ``doc_filename``.
    Connection parameters are read from the ELASTICSEARCH_URL / _ID / _PW
    environment variables (load them with load_dotenv before constructing).
    """

    def __init__(
        self,
        index_name: str = "wiki-index",
        data_path: Optional[str] = None,
        setting_path: Optional[str] = None,
        doc_filename: Optional[str] = None,
    ) -> None:
        """Connect to Elasticsearch and make sure the index exists.

        Args:
            index_name: name of the index to search.
            data_path: directory containing the document JSON (only needed
                when the index has to be built).
            setting_path: path to the index settings/mappings JSON file.
            doc_filename: document JSON filename inside ``data_path``.

        Raises:
            ValueError: if the index is missing and no build inputs were given.
        """
        self.index_name = index_name
        # NOTE(review): getenv may return None when the variables are unset —
        # basic_auth would then be (None, None); confirm the env is populated.
        self.client = self._connect_elasticsearch(
            os.getenv("ELASTICSEARCH_URL"), os.getenv("ELASTICSEARCH_ID"), os.getenv("ELASTICSEARCH_PW")
        )

        # Load the dataset and initialize the index only when it is missing
        if not self.client.indices.exists(index=self.index_name):
            if data_path and setting_path and doc_filename:
                docs = self._load_dataset(os.path.join(data_path, doc_filename))
                self._initialize_index(setting_path)
                self._insert_documents(docs)
            else:
                raise ValueError(f"์กด์ฌํ์ง ์๋ ์ธ๋ฑ์ค: {index_name}")

    def _connect_elasticsearch(self, url: str, id: str, pw: str) -> Elasticsearch:
        """Create an Elasticsearch client and log whether ping() succeeds."""
        es = Elasticsearch(
            url,
            basic_auth=(id, pw),
            request_timeout=30,
            max_retries=10,
            retry_on_timeout=True,
            verify_certs=False,  # NOTE(review): TLS verification disabled — only acceptable for internal/dev clusters
        )
        logger.info(f"Elasticsearch ์ฐ๊ฒฐ ์ํ: {es.ping()}")
        return es

    def _load_dataset(self, doc_filename) -> Dict:
        """Load the document dataset from a JSON file."""
        with open(doc_filename, "r", encoding="utf-8") as f:
            return json.load(f)

    def _initialize_index(self, setting_path) -> None:
        """Create the index using the settings/mappings JSON at setting_path."""
        with open(setting_path, "r") as f:
            setting = json.load(f)
        self.client.indices.create(index=self.index_name, body=setting)
        logger.info("์ธ๋ฑ์ค ์์ฑ ์๋ฃ")

    def _delete_index(self):
        # Maintenance helper: drop the index if it exists (no-op otherwise).
        if not self.client.indices.exists(index=self.index_name):
            logger.info("Index doesn't exist.")
            return

        self.client.indices.delete(index=self.index_name)
        logger.info("Index deletion has been completed")

    def _insert_documents(self, docs) -> None:
        """Bulk-insert documents into the index, flushing every 1000 docs."""

        def _preprocess(text):
            # Replace newlines, escaped newlines and '#' with spaces, then
            # collapse runs of whitespace into a single space.
            text = re.sub(r"\n", " ", text)
            text = re.sub(r"\\n", " ", text)
            text = re.sub(r"#", " ", text)
            text = re.sub(r"\s+", " ", text).strip()
            return text

        bulk_data = []
        for i, doc in enumerate(docs):
            # Metadata line for the bulk API (action + target index/id)
            bulk_data.append({"index": {"_index": self.index_name, "_id": i}})
            # Actual document payload
            bulk_data.append({"title": doc["title"], "text": _preprocess(doc["text"])})

            # Flush in batches of 1000 documents
            if (i + 1) % 1000 == 0:
                try:
                    response = self.client.bulk(body=bulk_data)
                    if response["errors"]:
                        logger.warning(f"{i+1}๋ฒ์งธ ๋ฒํฌ ์ฝ์ ์ค ์ผ๋ถ ์ค๋ฅ ๋ฐ์")
                    bulk_data = []  # reset the batch buffer
                    logger.info(f"{i+1}๊ฐ ๋ฌธ์ ๋ฒํฌ ์ฝ์ ์๋ฃ")
                except Exception as e:
                    logger.error(f"๋ฒํฌ ์ฝ์ ์คํจ (์ธ๋ฑ์ค: {i}): {e}")
                    bulk_data = []  # drop the failed batch and keep going

        # Flush whatever remains (< 1000 docs)
        if bulk_data:
            try:
                response = self.client.bulk(body=bulk_data)
                if response["errors"]:
                    logger.warning("๋ง์ง๋ง ๋ฒํฌ ์ฝ์ ์ค ์ผ๋ถ ์ค๋ฅ ๋ฐ์")
            except Exception as e:
                logger.error(f"๋ง์ง๋ง ๋ฒํฌ ์ฝ์ ์คํจ: {e}")

        # Final document count as a sanity check
        n_records = self.client.count(index=self.index_name)["count"]
        logger.info(f"์ด {n_records}๊ฐ ๋ฌธ์ ์ฝ์ ์๋ฃ")

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Search the index with a single query.

        Returns up to ``top_k`` dicts of the form
        ``{"text": "<title>: <body>", "score": <es _score>}``.
        """
        query_body = {"query": {"bool": {"must": [{"match": {"text": query}}]}}}

        response = self.client.search(index=self.index_name, body=query_body, size=top_k)

        results = []
        for hit in response["hits"]["hits"]:
            results.append({"text": f"{hit['_source']['title']}: {hit['_source']['text']}", "score": hit["_score"]})
        return results

    def bulk_retrieve(self, queries: List[str], top_k: int = 3) -> List[List[Dict]]:
        """Search many queries at once via the msearch API.

        Returns one result list per query, aligned with the input order;
        a query whose response carries an "error" yields an empty list, and a
        failed msearch call yields empty lists for every query.
        """
        logger.info(f"{len(queries)}๊ฐ ์ฟผ๋ฆฌ ์ผ๊ด ๊ฒ์")

        # msearch expects interleaved header/query lines
        bulk_query = []
        for query in queries:
            # Header line
            bulk_query.append({"index": self.index_name})
            # Query line
            bulk_query.append({"query": {"bool": {"must": [{"match": {"text": query}}]}}, "size": top_k})

        try:
            # Call the msearch API
            response = self.client.msearch(body=bulk_query)

            # Collect per-query hits
            results = []
            for response_item in response["responses"]:
                query_results = []
                if not response_item.get("error"):
                    for hit in response_item["hits"]["hits"]:
                        query_results.append(
                            {"text": f"{hit['_source']['title']}: {hit['_source']['text']}", "score": hit["_score"]}
                        )
                results.append(query_results)

            logger.info(f"{len(queries)}๊ฐ ์ฟผ๋ฆฌ ์ผ๊ด ๊ฒ์ ์๋ฃ")
            return results

        except Exception as e:
            logger.error(f"Bulk search ์คํจ: {e}")
            return [[] for _ in queries]  # empty results on error
+
+
if __name__ == "__main__":
    # Load Elasticsearch credentials from config/.env relative to this file
    config_folder = os.path.join(os.path.dirname(__file__), "..", "..", "config")
    load_dotenv(os.path.join(config_folder, ".env"))

    retriever = ElasticsearchRetriever(
        data_path="../data/",
        index_name="wiki-index",
        setting_path="../config/elastic_setting.json",
        doc_filename="wiki.json",
    )

    # Only used when appending new documents to an existing index.
    # NOTE(review): dead code guarded by `if False:` — flipped manually when
    # needed; consider a CLI flag instead.
    if False:
        current_count = retriever.client.count(index=retriever.index_name)["count"]

        logger.info("์๋ก์ด ๋ฌธ์ ์ถ๊ฐ ์์")
        with open("new_wiki.json", "r", encoding="utf-8") as f:
            new_docs = json.load(f)
        retriever._insert_documents(new_docs)

        # Verify the documents were added
        new_count = retriever.client.count(index=retriever.index_name)["count"]
        logger.info(f"๋ฌธ์ ์ถ๊ฐ ์๋ฃ: {current_count} -> {new_count} ({new_count-current_count}๊ฐ ์ถ๊ฐ)")

    # Retrieval smoke test
    query = "์ ๋น๋ค ์๋ง ๋ช์ด ๋๊ถ ์์ ๋ชจ์ฌ ๋ง ๋๋ฌ์ ์์์ ๋ค์ ์ค๋ฆฝํ ๊ฒ์ ์ฒญํ๋, (๊ฐ)์ด/๊ฐ ํฌ๊ฒ ๋ธํ์ฌ ํ์ฑ๋ถ์ ์กฐ๋ก(็้ท)์ ๋ณ์กธ๋ก ํ์ฌ ๊ธ ํ ๊ฐ ๋ฐ์ผ๋ก ๋ชฐ์๋ด๊ฒ ํ๊ณ ๋๋์ด ์ฒ์ฌ ๊ณณ์ ์์์ ์ฒ ํํ๊ณ ๊ทธ ํ ์ง๋ฅผ ๋ชฐ์ํ์ฌ ๊ด์ ์ํ๊ฒ ํ์๋ค.๏ผ๋ํ๊ณ๋์ฌ"  # noqa: E501
    results = retriever.retrieve(query, top_k=5)

    # Log the top-5 hits for manual inspection
    for i, result in enumerate(results, 1):
        logger.debug(f"\n๊ฒ์ ๊ฒฐ๊ณผ {i}")
        logger.debug(f"์ ์: {result['score']:.4f}")
        logger.debug(f"๋ด์ฉ: {result['text'][:200]}...")
diff --git a/code/rag/train.py b/code/rag/train.py
new file mode 100644
index 0000000..20a33fe
--- /dev/null
+++ b/code/rag/train.py
@@ -0,0 +1,228 @@
+from copy import deepcopy
+import logging
+import os
+from typing import Tuple
+
+from dpr_data import KorQuadSampler, korquad_collator
+import numpy as np
+import torch
+from tqdm import tqdm
+import transformers
+import wandb
+
+
# Ensure output directory exists
os.makedirs("./output", exist_ok=True)

# Set up logging: the root logger writes DEBUG+ records to logs/log.log
# (file-only — nothing is echoed to the console).
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    filename="logs/log.log",
    level=logging.DEBUG,
    format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s",
)
logger = logging.getLogger()  # get root logger
+
+
class Trainer:
    """Basic trainer for a KorQuad bi-encoder with in-batch negatives.

    Trains with Adam + a linear warmup schedule, validates every
    ``valid_every`` global steps, and checkpoints to ./output whenever the
    validation loss improves on the best seen so far.
    """

    def __init__(
        self,
        model,
        device,
        train_dataset,
        valid_dataset,
        num_epoch: int,
        batch_size: int,
        lr: float,
        betas: Tuple[float],
        num_warmup_steps: int,
        num_training_steps: int,
        valid_every: int,
        best_val_ckpt_path: str,
    ):
        """Build the optimizer, scheduler and data loaders.

        Args:
            model: bi-encoder exposing ``model(ids, mask, "query"|"passage")``.
            device: torch device to train on.
            train_dataset / valid_dataset: wrappers exposing ``.dataset`` and
                ``.pad_token_id`` (KorQuad-style).
            num_epoch: number of training epochs.
            batch_size: batch size for both loaders.
            lr, betas: Adam hyperparameters.
            num_warmup_steps, num_training_steps: linear schedule parameters.
            valid_every: validate every N global steps.
            best_val_ckpt_path: filename of the best-validation checkpoint.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=betas)
        self.scheduler = transformers.get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps, num_training_steps
        )
        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset.dataset,
            batch_sampler=KorQuadSampler(train_dataset.dataset, batch_size=batch_size, drop_last=False),
            collate_fn=lambda x: korquad_collator(x, padding_value=train_dataset.pad_token_id),
            num_workers=4,
        )
        self.valid_loader = torch.utils.data.DataLoader(
            dataset=valid_dataset.dataset,
            batch_sampler=KorQuadSampler(valid_dataset.dataset, batch_size=batch_size, drop_last=False),
            collate_fn=lambda x: korquad_collator(x, padding_value=valid_dataset.pad_token_id),
            num_workers=4,
        )

        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.valid_every = valid_every
        self.lr = lr
        self.betas = betas
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        self.best_val_ckpt_path = best_val_ckpt_path
        # Optimizer/scheduler state is stored next to the model checkpoint
        self.best_val_optim_path = best_val_ckpt_path.split(".pt")[0] + "_optim.pt"

        # Resume pointers; load_training_state() overwrites these
        self.start_ep = 1
        self.start_step = 1

    def ibn_loss(self, pred: torch.FloatTensor):
        """In-batch negative loss: the i-th query's positive is the i-th passage."""
        bsz = pred.size(0)
        target = torch.arange(bsz).to(self.device)  # diagonal is the answer
        return torch.nn.functional.cross_entropy(pred, target)

    def batch_acc(self, pred: torch.FloatTensor):
        """Fraction of queries whose top-scored passage is the diagonal one."""
        bsz = pred.size(0)
        target = torch.arange(bsz)
        return (pred.detach().cpu().max(1).indices == target).sum().float() / bsz

    def fit(self):
        """Train the model, validating every ``valid_every`` global steps."""
        wandb.init(
            project="kordpr",
            entity="lucas01",
            config={
                "batch_size": self.batch_size,
                "lr": self.lr,
                "betas": self.betas,
                "num_warmup_steps": self.num_warmup_steps,
                "num_training_steps": self.num_training_steps,
                "valid_every": self.valid_every,
            },
        )
        logger.debug("start training")
        self.model.train()  # Set model to training mode
        global_step_cnt = 0
        prev_best = None  # best validation loss seen so far
        for ep in range(self.start_ep, self.num_epoch + 1):
            for step, batch in enumerate(tqdm(self.train_loader, desc=f"epoch {ep} batch"), 1):
                if ep == self.start_ep and step < self.start_step:
                    continue  # Skip until the saved checkpoint when resuming

                self.model.train()  # evaluate() may have switched to eval mode
                global_step_cnt += 1
                q, q_mask, _, p, p_mask = batch
                q, q_mask, p, p_mask = (
                    q.to(self.device),
                    q_mask.to(self.device),
                    p.to(self.device),
                    p_mask.to(self.device),
                )
                q_emb = self.model(q, q_mask, "query")
                p_emb = self.model(p, p_mask, "passage")
                pred = torch.matmul(q_emb, p_emb.T)  # bsz x bsz similarity logits
                loss = self.ibn_loss(pred)
                acc = self.batch_acc(pred)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                log = {
                    "epoch": ep,
                    "step": step,
                    "global_step": global_step_cnt,
                    "train_step_loss": loss.cpu().item(),
                    "current_lr": float(self.scheduler.get_last_lr()[0]),
                    "step_acc": acc,
                }
                if global_step_cnt % self.valid_every == 0:
                    eval_dict = self.evaluate()
                    log.update(eval_dict)
                    if prev_best is None or eval_dict["valid_loss"] < prev_best:  # Save best validation model
                        # BUG FIX: prev_best was never updated, so every
                        # validation pass overwrote the "best" checkpoint.
                        prev_best = eval_dict["valid_loss"]
                        self.save_training_state(log)
                wandb.log(log)

    def evaluate(self):
        """Evaluate on the validation loader.

        Returns:
            dict with sample-weighted "valid_loss" and "valid_acc".
        """
        self.model.eval()  # Set model to evaluation mode
        loss_list = []
        sample_cnt = 0
        valid_acc = 0
        with torch.no_grad():
            for batch in self.valid_loader:
                q, q_mask, _, p, p_mask = batch
                q, q_mask, p, p_mask = (
                    q.to(self.device),
                    q_mask.to(self.device),
                    p.to(self.device),
                    p_mask.to(self.device),
                )
                q_emb = self.model(q, q_mask, "query")
                p_emb = self.model(p, p_mask, "passage")
                pred = torch.matmul(q_emb, p_emb.T)
                loss = self.ibn_loss(pred)
                step_acc = self.batch_acc(pred)

                # Weight per-batch metrics by batch size (last batch may be smaller)
                bsz = q.size(0)
                sample_cnt += bsz
                valid_acc += step_acc * bsz
                loss_list.append(loss.cpu().item() * bsz)
        return {
            "valid_loss": np.array(loss_list).sum() / float(sample_cnt),
            "valid_acc": valid_acc / float(sample_cnt),
        }

    def save_training_state(self, log_dict: dict) -> None:
        """Save model checkpoint plus optimizer/scheduler state under ./output."""
        checkpoint_path = os.path.join("./output", self.best_val_ckpt_path)
        self.model.checkpoint(checkpoint_path)
        training_state = {
            "optimizer_state": deepcopy(self.optimizer.state_dict()),
            "scheduler_state": deepcopy(self.scheduler.state_dict()),
        }
        training_state.update(log_dict)  # also keep epoch/step for resuming
        optim_path = os.path.join("./output", self.best_val_optim_path)
        torch.save(training_state, optim_path)
        logger.debug(f"Saved optimizer/scheduler state into {optim_path}")

    def load_training_state(self) -> None:
        """Load model, optimizer and scheduler state; set resume pointers."""
        checkpoint_path = os.path.join("./output", self.best_val_ckpt_path)
        if os.path.exists(checkpoint_path):
            self.model.load(checkpoint_path)
            optim_path = os.path.join("./output", self.best_val_optim_path)
            training_state = torch.load(optim_path)
            logger.debug(f"Loaded optimizer/scheduler state from {optim_path}")
            self.optimizer.load_state_dict(training_state["optimizer_state"])
            self.scheduler.load_state_dict(training_state["scheduler_state"])
            self.start_ep = training_state["epoch"]
            self.start_step = training_state["step"]
            logger.debug(f"Resumed training from epoch {self.start_ep} / step {self.start_step}")
        else:
            logger.debug("No checkpoint found, starting training from scratch.")
+
+
+# if __name__ == "__main__":
+# device = torch.device("cuda:0")
+# model = KobertBiEncoder()
+# train_dataset = KorQuadDataset("./data/KorQuAD_v1.0_train.json")
+# valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json")
+# my_trainer = Trainer(
+# model=model,
+# device=device,
+# train_dataset=train_dataset,
+# valid_dataset=valid_dataset,
+# num_epoch=1,
+# batch_size=128 - 32,
+# lr=1e-5,
+# betas=(0.9, 0.99),
+# num_warmup_steps=1000,
+# num_training_steps=100000,
+# valid_every=30,
+# best_val_ckpt_path="my_model.pt",
+# )
+# my_trainer.load_training_state()
+# my_trainer.fit() # Start training
+# # eval_dict = my_trainer.evaluate() # If you want to evaluate after training
+# # print(eval_dict)
diff --git a/code/rag/trainer.py b/code/rag/trainer.py
new file mode 100644
index 0000000..79502eb
--- /dev/null
+++ b/code/rag/trainer.py
@@ -0,0 +1,257 @@
+import os
+import sys
+
+
# Add this script's directory and its parent to sys.path so sibling modules
# (dpr_data, encoder) resolve regardless of the working directory.
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from copy import deepcopy
+import logging
+from typing import Tuple
+
+from dpr_data import KorQuadDataset, KorQuadSampler, korquad_collator
+from encoder import KobertBiEncoder
+import numpy as np
+import torch
+from tqdm import tqdm
+import transformers
+
+
+os.makedirs("logs", exist_ok=True)
+logging.basicConfig(
+ filename="logs/log.log",
+ level=logging.DEBUG,
+ format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s",
+)
+logger = logging.getLogger() # get root logger
+
+
class Trainer:
    """Basic trainer for a KorQuad bi-encoder with in-batch negatives.

    Trains with Adam + a linear warmup schedule, validates every
    ``valid_every`` global steps, and checkpoints whenever the validation
    loss improves on the best seen so far.
    """

    def __init__(
        self,
        model,
        device,
        train_dataset,
        valid_dataset,
        num_epoch: int,
        batch_size: int,
        lr: float,
        betas: Tuple[float],
        num_warmup_steps: int,
        num_training_steps: int,
        valid_every: int,
        best_val_ckpt_path: str,
    ):
        """Build the optimizer, scheduler and data loaders.

        Args:
            model: bi-encoder exposing ``model(ids, mask, "query"|"passage")``.
            device: torch device to train on.
            train_dataset / valid_dataset: wrappers exposing ``.dataset`` and
                ``.pad_token_id`` (KorQuad-style).
            num_epoch: number of training epochs.
            batch_size: batch size for both loaders.
            lr, betas: Adam hyperparameters.
            num_warmup_steps, num_training_steps: linear schedule parameters.
            valid_every: validate every N global steps.
            best_val_ckpt_path: path of the best-validation checkpoint.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=betas)
        self.scheduler = transformers.get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps, num_training_steps
        )
        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset.dataset,
            batch_sampler=KorQuadSampler(train_dataset.dataset, batch_size=batch_size, drop_last=False),
            collate_fn=lambda x: korquad_collator(x, padding_value=train_dataset.pad_token_id),
            num_workers=4,
        )
        self.valid_loader = torch.utils.data.DataLoader(
            dataset=valid_dataset.dataset,
            batch_sampler=KorQuadSampler(valid_dataset.dataset, batch_size=batch_size, drop_last=False),
            collate_fn=lambda x: korquad_collator(x, padding_value=valid_dataset.pad_token_id),
            num_workers=4,
        )

        self.batch_size = batch_size
        self.num_epoch = num_epoch
        self.valid_every = valid_every
        self.lr = lr
        self.betas = betas
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        self.best_val_ckpt_path = best_val_ckpt_path
        # Optimizer/scheduler state is stored next to the model checkpoint
        self.best_val_optim_path = best_val_ckpt_path.split(".pt")[0] + "_optim.pt"

        # Resume pointers; load_training_state() overwrites these
        self.start_ep = 1
        self.start_step = 1

    def ibn_loss(self, pred: torch.FloatTensor):
        """Compute the in-batch-negative loss for a batch.

        pred: bsz x bsz (or bsz x bsz*2 when hard negatives are appended)
        logits; the i-th query's positive passage sits on the diagonal.
        """
        bsz = pred.size(0)
        target = torch.arange(bsz).to(self.device)  # diagonal is the answer
        return torch.nn.functional.cross_entropy(pred, target)

    def batch_acc(self, pred: torch.FloatTensor):
        """Fraction of queries whose top-scored passage is the diagonal one."""
        bsz = pred.size(0)
        target = torch.arange(bsz)  # diagonal is the answer
        return (pred.detach().cpu().max(1).indices == target).sum().float() / bsz

    def fit(self):
        """Train the model, validating every ``valid_every`` global steps."""
        # wandb.init(
        #     project="personal",
        #     entity="gayean01",
        #     config={
        #         "batch_size": self.batch_size,
        #         "lr": self.lr,
        #         "betas": self.betas,
        #         "num_warmup_steps": self.num_warmup_steps,
        #         "num_training_steps": self.num_training_steps,
        #         "valid_every": self.valid_every,
        #     },
        # )
        logger.debug("start training")
        self.model.train()  # training mode
        global_step_cnt = 0
        prev_best = None  # best validation loss seen so far
        for ep in range(self.start_ep, self.num_epoch + 1):
            for step, batch in enumerate(tqdm(self.train_loader, desc=f"epoch {ep} batch"), 1):
                if ep == self.start_ep and step < self.start_step:
                    continue  # when resuming mid-epoch, skip already-seen steps

                self.model.train()  # evaluate() may have switched to eval mode
                global_step_cnt += 1
                q, q_mask, _, p, p_mask = batch
                q, q_mask, p, p_mask = (
                    q.to(self.device),
                    q_mask.to(self.device),
                    p.to(self.device),
                    p_mask.to(self.device),
                )
                q_emb = self.model(q, q_mask, "query")  # bsz x bert_dim
                p_emb = self.model(p, p_mask, "passage")  # bsz x bert_dim
                pred = torch.matmul(q_emb, p_emb.T)  # bsz x bsz
                loss = self.ibn_loss(pred)
                acc = self.batch_acc(pred)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                log = {
                    "epoch": ep,
                    "step": step,
                    "global_step": global_step_cnt,
                    "train_step_loss": loss.cpu().item(),
                    "current_lr": float(self.scheduler.get_last_lr()[0]),  # single parameter group
                    "step_acc": acc,
                }
                if global_step_cnt % self.valid_every == 0:
                    eval_dict = self.evaluate()
                    log.update(eval_dict)
                    if prev_best is None or eval_dict["valid_loss"] < prev_best:  # save on best val loss
                        # BUG FIX: prev_best was never updated, so every
                        # validation pass overwrote the "best" checkpoint.
                        prev_best = eval_dict["valid_loss"]
                        # self.model.checkpoint(self.best_val_ckpt_path)
                        self.save_training_state(log)
                # wandb.log(log)

    def evaluate(self):
        """Evaluate on the validation loader.

        Returns:
            dict with sample-weighted "valid_loss" and "valid_acc".
        """
        self.model.eval()  # evaluation mode
        loss_list = []
        sample_cnt = 0
        valid_acc = 0
        with torch.no_grad():
            for batch in self.valid_loader:
                q, q_mask, _, p, p_mask = batch
                q, q_mask, p, p_mask = (
                    q.to(self.device),
                    q_mask.to(self.device),
                    p.to(self.device),
                    p_mask.to(self.device),
                )
                q_emb = self.model(q, q_mask, "query")  # bsz x bert_dim
                p_emb = self.model(p, p_mask, "passage")  # bsz x bert_dim
                pred = torch.matmul(q_emb, p_emb.T)  # bsz x bsz
                loss = self.ibn_loss(pred)
                step_acc = self.batch_acc(pred)

                # Weight per-batch metrics by batch size (last batch may be smaller)
                bsz = q.size(0)
                sample_cnt += bsz
                valid_acc += step_acc * bsz
                loss_list.append(loss.cpu().item() * bsz)
        valid_loss = np.array(loss_list).sum() / float(sample_cnt)
        valid_acc = valid_acc / float(sample_cnt)

        # Console/file report
        logger.info(f"Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_acc:.4f}")
        # BUG FIX: the returned dict previously divided by sample_cnt a second
        # time, under-reporting both metrics by a factor of sample_cnt.
        return {
            "valid_loss": valid_loss,
            "valid_acc": valid_acc,
        }

    def save_training_state(self, log_dict: dict) -> None:
        """Save the model checkpoint plus optimizer/scheduler state."""
        self.model.checkpoint(self.best_val_ckpt_path)
        training_state = {
            "optimizer_state": deepcopy(self.optimizer.state_dict()),
            "scheduler_state": deepcopy(self.scheduler.state_dict()),
        }
        training_state.update(log_dict)  # also keep epoch/step for resuming
        torch.save(training_state, self.best_val_optim_path)
        logger.debug(f"saved optimizer/scheduler state into {self.best_val_optim_path}")

    def load_training_state(self) -> None:
        """Load model, optimizer and scheduler state; set resume pointers."""
        self.model.load(self.best_val_ckpt_path)
        training_state = torch.load(self.best_val_optim_path)
        logger.debug(f"loaded optimizer/scheduler state from {self.best_val_optim_path}")
        self.optimizer.load_state_dict(training_state["optimizer_state"])
        self.scheduler.load_state_dict(training_state["scheduler_state"])
        self.start_ep = training_state["epoch"]
        self.start_step = training_state["step"]
        logger.debug(f"resume training from epoch {self.start_ep} / step {self.start_step}")
+
+
# Helper used by the entry point to decide whether training is needed.
def check_if_model_exists(model_path: str):
    """Return True when ``model_path`` already exists on the filesystem."""
    exists = os.path.exists(model_path)
    return exists
+
+
# Main entry point
if __name__ == "__main__":
    # Checkpoint location
    model_path = "./output/my_model.pt"

    # Train only when no checkpoint exists yet
    if not check_if_model_exists(model_path):
        logger.info(f"๋ชจ๋ธ '{model_path}'์ด ์กด์ฌํ์ง ์์ต๋๋ค. ํ์ต์ ์์ํฉ๋๋ค.")

        # Prepare device, model and datasets
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = KobertBiEncoder()
        train_dataset = KorQuadDataset("./data/KorQuAD_v1.0_train.json")
        valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json")

        # Build the trainer
        my_trainer = Trainer(
            model=model,
            device=device,
            train_dataset=train_dataset,
            valid_dataset=valid_dataset,
            num_epoch=1,  # number of training epochs
            batch_size=32,  # batch size
            lr=1e-5,
            betas=(0.9, 0.99),
            num_warmup_steps=100,
            num_training_steps=1000,
            valid_every=100,
            best_val_ckpt_path=model_path,
        )

        # Run training, then a final evaluation
        my_trainer.fit()
        eval_dict = my_trainer.evaluate()
        logger.info(eval_dict)

        # Save the final weights
        # NOTE(review): this state_dict save overwrites the best-validation
        # checkpoint written by save_training_state with the *last* weights,
        # and in a different serialization format — confirm this is intended.
        os.makedirs("output", exist_ok=True)
        torch.save(model.state_dict(), model_path)
        logger.info(f"ํ์ต ์๋ฃ. ๋ชจ๋ธ์ด '{model_path}'์ ์ ์ฅ๋์์ต๋๋ค.")
    else:
        logger.info(f"๋ชจ๋ธ '{model_path}'์ด ์ด๋ฏธ ์กด์ฌํฉ๋๋ค. ํ์ต์ ๊ฑด๋๋๋๋ค.")
diff --git a/code/rag/utils.py b/code/rag/utils.py
new file mode 100644
index 0000000..084b1cd
--- /dev/null
+++ b/code/rag/utils.py
@@ -0,0 +1,36 @@
from glob import glob
import math
import os
import typing

import torch
+
+
def get_wiki_filepath(data_dir):
    """Collect every extracted wiki shard matching ``<data_dir>/<sub>/wiki_*``."""
    pattern = f"{data_dir}/*/wiki_*"
    return glob(pattern)
+
+
def wiki_worker_init(worker_id):
    """DataLoader ``worker_init_fn``: partition the dataset's [start, end)
    range evenly across workers so each worker reads a disjoint slice.

    NOTE(review): assumes the worker's dataset exposes mutable ``start`` and
    ``end`` attributes (iterable-style dataset) — confirm against the caller.
    """
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    # print(dataset)
    # dataset =
    overall_start = dataset.start
    overall_end = dataset.end
    # Ceil so the slices cover the whole range; the last worker absorbs the remainder
    split_size = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    # end_idx = min((worker_id+1) * split_size, len(dataset.data))
    dataset.start = overall_start + worker_id * split_size
    dataset.end = min(dataset.start + split_size, overall_end)  # prevent index errors
+
+
def get_passage_file(p_id_list: typing.List[int]) -> str:
    """Find the processed-passage file whose id range covers all given ids.

    Files are named ``<start>-<end>.p`` under ``processed_passages/``
    (relative to the working directory).

    Args:
        p_id_list: passage ids that must all fall inside one file's range.

    Returns:
        The first file path whose [start, end] range contains every id, or
        None when no file covers them. (Ranges are assumed disjoint, so
        returning on the first match is equivalent to the old last-match scan.)
    """
    p_id_max = max(p_id_list)
    p_id_min = min(p_id_list)
    for f in glob("processed_passages/*.p"):
        # Parse "<start>-<end>" from the file name. Use basename/splitext
        # instead of splitting on "/" so OS-specific separators also work.
        name = os.path.splitext(os.path.basename(f))[0]
        s, e = (int(x) for x in name.split("-"))
        if p_id_min >= s and p_id_max <= e:
            return f
    return None
diff --git a/code/split.py b/code/split.py
new file mode 100644
index 0000000..361573a
--- /dev/null
+++ b/code/split.py
@@ -0,0 +1,85 @@
+from ast import literal_eval
+
+from loguru import logger
+import pandas as pd
+
+
def load_data(file_path):
    """Load the exam CSV and unpack the stringified ``problems`` column.

    Each row's ``problems`` field holds a Python-literal dict with
    ``question``, ``choices`` and (optionally) ``answer``.

    Returns:
        (raw DataFrame, list of flattened record dicts).
    """
    data = pd.read_csv(file_path)
    records = []
    for _, row in data.iterrows():
        # literal_eval parses the stringified dict without executing code
        problems = literal_eval(row["problems"])
        record = {
            "id": row["id"],
            "paragraph": row["paragraph"],
            "question": problems["question"],
            "choices": problems["choices"],
            "answer": problems.get("answer", None),
        }
        records.append(record)
    logger.debug(records[0])  # print the first record (for debugging)
    return data, records
+
+
def classify_questions(records):
    """Heuristically label each record as social-studies ("์ฌํ"),
    Korean-language ("๊ตญ์ด") or unclear ("๋ถํ์ค").

    A question containing any exam-style keyword is "์ฌํ"; otherwise, if the
    correct choice appears verbatim in the paragraph, it is "๊ตญ์ด".

    Returns:
        One dict per record with the label and the intermediate signals.
    """
    # BUG FIX: '"์ ๊ธ" "๋จ๋ฝ"' was missing a comma, so Python concatenated the
    # two adjacent literals into the single keyword "์ ๊ธ๋จ๋ฝ" and neither
    # keyword ever matched on its own.
    social_keywords = ["ใฑ.", "ใ ", "์์", "์ ๊ธ", "๋จ๋ฝ", "๋ณธ๋ฌธ", "๋ฐ์ค ์น", "(๊ฐ)", "๋ค์", "์๊ธฐ"]
    classifications = []

    for record in records:
        question = record["question"]
        paragraph = record["paragraph"]

        # Social-studies signal: a keyword appears in the question text
        contains_social_keywords = any(keyword in question for keyword in social_keywords)

        # Does each choice literally appear in the paragraph?
        choices_found_in_paragraph = {choice: choice in paragraph for choice in record["choices"]}

        # Is the correct choice contained in the paragraph?
        # NOTE(review): `answer` is used as a 0-based index into choices here —
        # confirm the dataset is not 1-based before relying on this flag.
        answer_index = record["answer"]
        answer_found_in_paragraph = False

        if answer_index is not None and 0 <= answer_index < len(record["choices"]):
            answer = record["choices"][answer_index]  # the correct choice
            answer_found_in_paragraph = choices_found_in_paragraph.get(answer, False)

        # if contains_social_keywords and not answer_found_in_paragraph:
        #     classification = '์ฌํ'
        if contains_social_keywords:
            classification = "์ฌํ"
        elif not contains_social_keywords and answer_found_in_paragraph:
            classification = "๊ตญ์ด"
        else:
            classification = "๋ถํ์ค"  # neither condition applies

        classifications.append(
            {
                "id": record["id"],
                "classification": classification,
                "contains_social_keywords": contains_social_keywords,
                "answer_found_in_paragraph": answer_found_in_paragraph,
                "choices_found_in_paragraph": choices_found_in_paragraph,  # per-choice containment flags
            }
        )

    return classifications
+
+
def main():
    """Classify every training question and write the results to CSV."""
    file_path = "../data/train.csv"  # input path
    data, records = load_data(file_path)

    classifications = classify_questions(records)

    # Convert the results to a DataFrame
    result_df = pd.DataFrame(classifications)

    # Persist as CSV (utf-8-sig so spreadsheet apps render Korean correctly)
    output_file_path = "../data/classification_results.csv"  # output path
    result_df.to_csv(output_file_path, index=False, encoding="utf-8-sig")  # save as CSV

    logger.debug(f"๊ฒฐ๊ณผ๊ฐ {output_file_path}์ ์ ์ฅ๋์์ต๋๋ค.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/code/trainer.py b/code/trainer.py
new file mode 100644
index 0000000..5304a5c
--- /dev/null
+++ b/code/trainer.py
@@ -0,0 +1,115 @@
+import evaluate
+import numpy as np
+from peft import LoraConfig
+import torch
+from transformers import EarlyStoppingCallback
+from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
+
+
class CustomTrainer:
    """LoRA/SFT fine-tuning wrapper around trl's SFTTrainer for a 5-way
    multiple-choice task whose answers are the single tokens "1".."5".
    """

    def __init__(self, training_config, model, tokenizer, train_dataset, eval_dataset):
        # training_config: nested dict with "response_template", "lora" and
        # "params" sections (see _setup_trainer for the keys consumed).
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.training_config = training_config
        self.acc_metric = evaluate.load("accuracy")
        # Maps the answer token ("1".."5") to a 0-based class index
        self.int_output_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}

    def train(self):
        """Run fine-tuning and return the trained (PEFT-wrapped) model."""
        trainer = self._setup_trainer()
        trainer.train()
        return trainer.model

    def _setup_trainer(self):
        """Assemble collator, LoRA config, SFT config and the SFTTrainer."""
        # Completion-only collator: loss is computed only on tokens after the
        # response template
        data_collator = DataCollatorForCompletionOnlyLM(
            response_template=self.training_config["response_template"],
            tokenizer=self.tokenizer,
        )

        # LoRA configuration
        peft_config = LoraConfig(
            r=self.training_config["lora"]["r"],
            lora_alpha=self.training_config["lora"]["lora_alpha"],
            lora_dropout=self.training_config["lora"]["lora_dropout"],
            target_modules=self.training_config["lora"]["target_modules"],
            bias=self.training_config["lora"]["bias"],
            task_type=self.training_config["lora"]["task_type"],
        )

        # SFT configuration (all values come from the "params" section)
        sft_config = SFTConfig(
            do_train=self.training_config["params"]["do_train"],
            do_eval=self.training_config["params"]["do_eval"],
            lr_scheduler_type=self.training_config["params"]["lr_scheduler_type"],
            max_seq_length=self.training_config["params"]["max_seq_length"],
            per_device_train_batch_size=self.training_config["params"]["per_device_train_batch_size"],
            per_device_eval_batch_size=self.training_config["params"]["per_device_eval_batch_size"],
            gradient_accumulation_steps=self.training_config["params"]["gradient_accumulation_steps"],
            gradient_checkpointing=self.training_config["params"]["gradient_checkpointing"],
            max_grad_norm=self.training_config["params"]["max_grad_norm"],
            num_train_epochs=self.training_config["params"]["num_train_epochs"],
            learning_rate=self.training_config["params"]["learning_rate"],
            weight_decay=self.training_config["params"]["weight_decay"],
            optim=self.training_config["params"]["optim"],
            logging_strategy=self.training_config["params"]["logging_strategy"],
            save_strategy=self.training_config["params"]["save_strategy"],
            eval_strategy=self.training_config["params"]["eval_strategy"],
            logging_steps=self.training_config["params"]["logging_steps"],
            save_steps=self.training_config["params"]["save_steps"],
            eval_steps=self.training_config["params"]["eval_steps"],
            save_total_limit=self.training_config["params"]["save_total_limit"],
            save_only_model=self.training_config["params"]["save_only_model"],
            load_best_model_at_end=self.training_config["params"]["load_best_model_at_end"],
            report_to=self.training_config["params"]["report_to"],
            run_name=self.training_config["params"]["run_name"],
            output_dir=self.training_config["params"]["output_dir"],
            overwrite_output_dir=self.training_config["params"]["overwrite_output_dir"],
            metric_for_best_model=self.training_config["params"]["metric_for_best_model"],
        )

        return SFTTrainer(
            model=self.model,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=data_collator,
            tokenizer=self.tokenizer,
            peft_config=peft_config,
            args=sft_config,
            compute_metrics=self._compute_metrics,
            preprocess_logits_for_metrics=self._preprocess_logits_for_metrics,
            callbacks=[
                EarlyStoppingCallback(
                    early_stopping_patience=self.training_config["params"]["early_stop_patience"],
                    early_stopping_threshold=self.training_config["params"]["early_stop_threshold"],
                )
            ],
        )

    # Reduce the model's logits to just the five answer-token positions
    def _preprocess_logits_for_metrics(self, logits, labels):
        logits = logits if not isinstance(logits, tuple) else logits[0]
        # Vocabulary ids of the tokens "1".."5"
        logit_idx = [
            self.tokenizer.vocab["1"],
            self.tokenizer.vocab["2"],
            self.tokenizer.vocab["3"],
            self.tokenizer.vocab["4"],
            self.tokenizer.vocab["5"],
        ]
        return logits[:, -2, logit_idx]  # -2: answer token, -1: eos token

    # Metric computation for the Trainer
    def _compute_metrics(self, evaluation_result):
        logits, labels = evaluation_result
        # Decode the tokenized labels (replace ignore-index -100 with pad)
        labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
        labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        # NOTE(review): str.split("") raises ValueError (empty separator) —
        # the delimiter (likely a special end-of-turn token) appears to have
        # been lost; restore it before running evaluation.
        labels = [x.split("")[0].strip() for x in labels]
        labels = [self.int_output_map[x] for x in labels]

        # Softmax over the five answer logits, then argmax → predicted class
        probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1)
        predictions = np.argmax(probs, axis=-1)

        return self.acc_metric.compute(predictions=predictions, references=labels)
diff --git a/code/utils/__init__.py b/code/utils/__init__.py
new file mode 100644
index 0000000..04618c3
--- /dev/null
+++ b/code/utils/__init__.py
@@ -0,0 +1,13 @@
"""
Utility modules used across the project.

## Main features
- hf_manager.py: upload models/datasets to the Hugging Face Hub
- gdrive_manager.py: automatically upload config & output to Google Drive
- common.py: collection of helpers for seeding and logging setup

"""
+
+from .common import create_experiment_filename, load_config, load_env_file, log_config, set_logger, set_seed, timer
+from .gdrive_manager import GoogleDriveManager
+from .hf_manager import HuggingFaceHubManager
diff --git a/code/utils/common.py b/code/utils/common.py
new file mode 100644
index 0000000..ac5aa5b
--- /dev/null
+++ b/code/utils/common.py
@@ -0,0 +1,106 @@
+import argparse
+from contextlib import contextmanager
+from datetime import datetime
+import os
+import random
+import time
+
+from dotenv import load_dotenv
+from loguru import logger
+import numpy as np
+import torch
+import yaml
+from zoneinfo import ZoneInfo
+
+
+# ์ฝ๋ ์ ์ญ์์ ์ฒซ ์คํ ์์ ์ ํ์์คํฌํ๋ฅผ ๋์ผํ๊ฒ ์ฌ์ฉ
+CURRENT_TIME = None
+
+
+def set_seed(seed=42):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed) # if use multi-GPU
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ np.random.seed(seed)
+ random.seed(seed)
+
+
+def load_config():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="config.yaml")
+ args = parser.parse_args()
+
+ with open(os.path.join("../config", args.config), encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+ return config
+
+
+def load_env_file(filepath="../config/.env"):
+ try:
+ # .env ํ์ผ ๋ก๋ ์๋
+ if load_dotenv(filepath):
+ logger.debug(f".env ํ์ผ์ ์ฑ๊ณต์ ์ผ๋ก ๋ก๋ํ์ต๋๋ค: {filepath}")
+ else:
+ raise FileNotFoundError # ํ์ผ์ด ์์ผ๋ฉด ์์ธ ๋ฐ์
+ except FileNotFoundError:
+ logger.debug(f"๊ฒฝ๊ณ : ์ง์ ๋ .env ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {filepath}")
+ except Exception as e:
+ logger.debug(f"์ค๋ฅ ๋ฐ์: .env ํ์ผ ๋ก๋ ์ค ์์ธ๊ฐ ๋ฐ์ํ์ต๋๋ค: {e}")
+
+
+def set_logger(log_file="../log/file.log", log_level="DEBUG"):
+ # ๋ก๊ฑฐ ์ค์
+ logger.add(
+ log_file,
+ format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
+ level=log_level,
+ rotation="12:00", # ๋งค์ผ 12์์ ์๋ก์ด ๋ก๊ทธ ํ์ผ ์์ฑ
+ retention="7 days", # 7์ผ ํ ๋ก๊ทธ ์ ๊ฑฐ
+ )
+
+
+# config ํ์ธ
+def log_config(config, depth=0):
+ if depth == 0:
+ print("*" * 40)
+ for k, v in config.items():
+ prefix = ["\t" * depth, k, ":"]
+
+ if isinstance(v, dict):
+ print(*prefix)
+ log_config(v, depth + 1)
+ else:
+ prefix.append(v)
+ print(*prefix)
+ if depth == 0:
+ print("*" * 40)
+
+
+def get_current_time():
+ global CURRENT_TIME
+ if CURRENT_TIME is None:
+ CURRENT_TIME = datetime.now(ZoneInfo("Asia/Seoul")).strftime("%m%d%H%M")
+ return CURRENT_TIME
+
+
+def create_experiment_filename(config):
+ if config is None:
+ config = load_config()
+ username = config["exp"]["username"]
+ base_model = config["model"]["base_model"].replace("/", "_")
+ train_path = config["data"]["train_path"]
+ train_name = os.path.splitext(os.path.basename(train_path))[0]
+ num_train_epochs = config["training"]["params"]["num_train_epochs"]
+ learning_rate = config["training"]["params"]["learning_rate"]
+ current_time = get_current_time()
+
+ return f"{username}_{base_model}_{train_name}_{num_train_epochs}_{learning_rate}_{current_time}"
+
+
+@contextmanager
+def timer(name):
+ t0 = time.time()
+ yield
+ logger.debug(f"[{name}] done in {time.time() - t0:.3f} s")
diff --git a/code/utils/gdrive_manager.py b/code/utils/gdrive_manager.py
new file mode 100755
index 0000000..a882613
--- /dev/null
+++ b/code/utils/gdrive_manager.py
@@ -0,0 +1,196 @@
+import io
+import json
+import os.path
+
+from dotenv import load_dotenv
+from google.auth.transport.requests import Request
+from google.oauth2.credentials import Credentials
+from google_auth_oauthlib.flow import InstalledAppFlow
+from googleapiclient.discovery import build
+from googleapiclient.http import MediaFileUpload, MediaIoBaseUpload
+from loguru import logger
+import pandas as pd
+
+
+SCOPES = [
+ "https://www.googleapis.com/auth/drive.file",
+ "https://www.googleapis.com/auth/drive",
+]
+
+
+class GoogleDriveManager:
+ def __init__(self):
+ config_folder = os.path.join(os.path.dirname(__file__), "..", "..", "config")
+ load_dotenv(os.path.join(config_folder, ".env"))
+ self.config_folder = config_folder
+ self.root_folder_id = os.getenv("GDRIVE_FOLDER_ID")
+ self.credentials = os.getenv("GDRIVE_CREDENTIALS")
+ self.token = os.getenv("GDRIVE_TOKEN")
+ self.is_create_token = os.getenv("GDRIVE_CREATE_TOKEN")
+
+ # ํ๊ฒฝ ๋ณ์ ๊ฒ์ฆ ์ถ๊ฐ
+ if not all([self.root_folder_id, self.credentials, self.token]):
+ logger.error(f"ํ์ ํ๊ฒฝ ๋ณ์๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค. {[self.root_folder_id, self.credentials, self.token]}")
+ raise ValueError("ํ์ ํ๊ฒฝ ๋ณ์๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค.")
+
+ self.service = self.get_drive_service()
+
+ def get_drive_service(self):
+ creds = None
+ if os.path.exists(self.token):
+ creds = Credentials.from_authorized_user_file(self.token, SCOPES)
+
+ if not creds or not creds.valid:
+ if creds and creds.expired and creds.refresh_token:
+ creds.refresh(Request())
+ elif self.is_create_token == "true":
+ flow = InstalledAppFlow.from_client_secrets_file(self.credentials, scopes=SCOPES)
+ creds = flow.run_local_server(port=0, open_browser=False)
+ # ๋ฆฌํ๋ ์ ํ ํฐ ์ ์ฅ
+ with open(self.token, "w") as token_file:
+ token_data = {
+ "refresh_token": creds.refresh_token,
+ "token": creds.token,
+ "token_uri": creds.token_uri,
+ "client_id": creds.client_id,
+ "client_secret": creds.client_secret,
+ "scopes": creds.scopes,
+ }
+ json.dump(token_data, token_file)
+ logger.info("๊ตฌ๊ธ ๋๋ผ์ด๋ธ ํ ํฐ์ด ๊ฐฑ์ ๋์์ต๋๋ค.")
+ else:
+ logger.error("๊ตฌ๊ธ ๋๋ผ์ด๋ธ ์
๋ก๋ ์คํจ. ํ ํฐ์ ๊ฐฑ์ ์ ์์ฒญํ์ธ์.")
+ with open(self.token, "w") as token:
+ token.write(creds.to_json())
+ return build("drive", "v3", credentials=creds)
+
+ def find_folder_id_by_name(self, folder_name, parent_folder_id=None):
+ """ํด๋๋ช
์ผ๋ก ํด๋ ID ์ฐพ๊ธฐ"""
+ if not parent_folder_id:
+ parent_folder_id = self.root_folder_id
+
+ # ํน์ ํด๋๋ช
๊ณผ ์ ํํ ์ผ์นํ๋ ํด๋ ๊ฒ์ ์ฟผ๋ฆฌ
+ query = f"name='{folder_name}' and "
+ query += f"'{parent_folder_id}' in parents and "
+ query += "mimeType='application/vnd.google-apps.folder' and "
+ query += "trashed=false"
+
+ try:
+ results = (
+ self.service.files()
+ .list(
+ q=query,
+ spaces="drive",
+ fields="files(id, name)",
+ pageSize=1, # ์ฒซ ๋ฒ์งธ ์ผ์นํ๋ ํด๋๋ง ํ์
+ )
+ .execute()
+ )
+
+ files = results.get("files", [])
+
+ if not files:
+ logger.info(f"ํด๋๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค: {folder_name}")
+ return None
+
+ return files[0]["id"]
+
+ except Exception as e:
+ logger.info(f"ํด๋ ๊ฒ์ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
+ return None
+
+ def list_folder_files(self, folder_id=None):
+ """ํด๋ ๋ด ํ์ผ ๋ชฉ๋ก ์กฐํ"""
+ if not folder_id:
+ folder_id = self.root_folder_id
+ query = f"'{folder_id}' in parents and trashed=false"
+
+ try:
+ results = (
+ self.service.files()
+ .list(
+ q=query,
+ pageSize=100,
+ fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)",
+ )
+ .execute()
+ )
+
+ return results.get("files", [])
+ except Exception as e:
+ logger.info(f"Error listing files: {str(e)}")
+ return []
+
+ def upload_yaml_file(self, file_path, filename, folder_id=None):
+ """YAML ํ์ผ ๊ฒฝ๋ก๋ฅผ ๋ฐ์์ ์
๋ก๋"""
+ try:
+ # ํ์ผ ๋ฉํ๋ฐ์ดํฐ ์ค์
+ file_metadata = {"name": filename, "mimeType": "application/x-yaml"}
+ if folder_id:
+ file_metadata["parents"] = [folder_id]
+
+ # ๋ฏธ๋์ด ๊ฐ์ฒด ์์ฑ
+ media = MediaFileUpload(file_path, mimetype="application/x-yaml", resumable=True)
+
+ # ํ์ผ ์
๋ก๋
+ file = self.service.files().create(body=file_metadata, media_body=media, fields="id, name").execute()
+
+ logger.debug(f"Successfully uploaded {filename} to Google Drive")
+ return file
+
+ except FileNotFoundError:
+ logger.error(f"File not found: {file_path}")
+ return None
+ except Exception as e:
+ logger.error(f"Error uploading YAML file: {str(e)}")
+ return None
+
+ def upload_dataframe(self, dataframe, filename, folder_id=None):
+ """Pandas DataFrame ์ง์ ์
๋ก๋"""
+ try:
+ # DataFrame์ CSV ์คํธ๋ฆผ์ผ๋ก ๋ณํ
+ buffer = io.StringIO()
+ dataframe.to_csv(buffer, index=False)
+ file_stream = io.BytesIO(buffer.getvalue().encode("utf-8"))
+
+ # ํ์ผ ๋ฉํ๋ฐ์ดํฐ ์ค์
+ file_metadata = {"name": filename, "mimeType": "text/csv"}
+ if folder_id:
+ file_metadata["parents"] = [folder_id]
+
+ # ๋ฏธ๋์ด ๊ฐ์ฒด ์์ฑ
+ media = MediaIoBaseUpload(file_stream, mimetype="text/csv", resumable=True)
+
+ # ํ์ผ ์
๋ก๋
+ file = self.service.files().create(body=file_metadata, media_body=media, fields="id, name").execute()
+ return file
+
+ except Exception as e:
+ logger.error(f"Error uploading DataFrame: {str(e)}")
+ return None
+
+ def upload_exp(self, user_name, output_path, config_path=None):
+ df = pd.read_csv(output_path)
+ df_basename = os.path.basename(output_path)
+
+ if config_path is None:
+ config_path = os.path.join(self.config_folder, "config.yaml")
+ config_basename = df_basename.replace("output.csv", "config.yaml")
+
+ # ์คํ์๋ช
์ผ๋ก ํด๋๋ช
์ฐพ๊ธฐ
+ folder_id = self.find_folder_id_by_name(user_name)
+ _ = self.upload_dataframe(df, df_basename, folder_id)
+ _ = self.upload_yaml_file(config_path, config_basename, folder_id)
+
+ gdrive_url = f"https://drive.google.com/drive/folders/{folder_id}"
+ logger.info(f"๊ตฌ๊ธ ๋๋ผ์ด๋ธ์ ์
๋ก๋ ๋์์ต๋๋ค: {gdrive_url}")
+
+
+if __name__ == "__main__":
+ os.chdir("..")
+ load_dotenv("../config/.env")
+ drive_manager = GoogleDriveManager()
+ # ํ์ผ ๋ชฉ๋ก ์กฐํ
+ files = drive_manager.list_folder_files()
+ for file in files:
+ logger.info(f"Name: {file['name']}, ID: {file['id']}")
diff --git a/code/utils/hf_manager.py b/code/utils/hf_manager.py
new file mode 100755
index 0000000..976798f
--- /dev/null
+++ b/code/utils/hf_manager.py
@@ -0,0 +1,81 @@
+import os
+
+from datasets import load_dataset
+from dotenv import load_dotenv
+from huggingface_hub import HfApi
+from loguru import logger
+from peft import AutoPeftModelForCausalLM
+from transformers import AutoTokenizer
+
+
+class HuggingFaceHubManager:
+ def __init__(self):
+ load_dotenv(os.path.join(os.path.dirname(__file__), "..", "..", "config", ".env"))
+ self.token = os.getenv("HF_TOKEN")
+ self.organization = os.getenv("HF_TEAM_NAME")
+ self.project_name = os.getenv("HF_PROJECT_NAME")
+
+ # ํ๊ฒฝ ๋ณ์ ๊ฒ์ฆ ์ถ๊ฐ
+ if not all([self.token, self.organization, self.project_name]):
+ raise ValueError("ํ์ ํ๊ฒฝ ๋ณ์๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค.")
+
+ def upload_model(self, model_name, username, checkpoint_path):
+ repo_id = f"{model_name}-{username}"
+ try:
+ model = AutoPeftModelForCausalLM.from_pretrained(
+ checkpoint_path,
+ trust_remote_code=True,
+ device_map="auto",
+ )
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)
+
+ model.push_to_hub(repo_id=repo_id, organization=self.organization, use_auth_token=self.token)
+ tokenizer.push_to_hub(repo_id=repo_id, organization=self.organization, use_auth_token=self.token)
+ logger.debug(f"your model pushed successfully in {repo_id}, hugging face")
+ except Exception as e:
+ logger.debug(f"An error occurred while uploading to Hugging Face: {e}")
+
+ def upload_dataset(self, file_name, private=True):
+ """
+ ํด๋ ๋ด ๋ฐ์ดํฐ ํ์ผ๋ค์ Hugging Face Hub์ ๋ฐ์ดํฐ์
์ผ๋ก ์
๋ก๋ํ๋ ํจ์.
+
+ Parameters:
+ - file_name (str): ์
๋ก๋ํ ๋ก์ปฌ ๋ฐ์ดํฐ ํ์ผ ์ด๋ฆ
+ - token (str): Hugging Face ์ก์ธ์ค ํ ํฐ. ์ฐ๊ธฐ ๊ถํ ํ์
+ - private (bool): True๋ฉด ๋น๊ณต๊ฐ, False๋ฉด ๊ณต๊ฐ ์ค์
+
+ Returns:
+ - None
+ """
+ api = HfApi()
+ repo_id = f"{self.organization}/{self.project_name}-{file_name}"
+
+ # ๋ฆฌํฌ์งํ ๋ฆฌ ์กด์ฌ ์ฌ๋ถ ํ์ธ
+ try:
+ api.repo_info(repo_id, repo_type="dataset", token=self.token)
+ logger.debug(f"'{repo_id}' ๋ฆฌํฌ์งํ ๋ฆฌ๊ฐ ์ด๋ฏธ ์กด์ฌํฉ๋๋ค. ๊ธฐ์กด ๋ฆฌํฌ์งํ ๋ฆฌ์ ๋ฐ์ดํฐ์
์ ์
๋ก๋ํฉ๋๋ค.")
+ except Exception:
+ # ๋ฆฌํฌ์งํ ๋ฆฌ๊ฐ ์์ผ๋ฉด ์์ฑ
+ logger.debug(f"'{repo_id}' ๋ฆฌํฌ์งํ ๋ฆฌ๊ฐ ์กด์ฌํ์ง ์์ต๋๋ค. ์๋ก ์์ฑํ ํ ์
๋ก๋ํฉ๋๋ค.")
+ api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, token=self.token)
+
+ # ํ์ผ ๊ฒฝ๋ก ์ค์
+ file_path = os.path.join("..", "data", f"{file_name}.csv")
+ if not os.path.exists(file_path):
+ logger.debug(f"ํ์ผ '{file_path}'์ด ์กด์ฌํ์ง ์์ต๋๋ค.")
+ return
+
+ # ๋ฐ์ดํฐ์
๋ก๋ ๋ฐ ์
๋ก๋
+ dataset = load_dataset("csv", data_files={"train": file_path})
+ dataset.push_to_hub(repo_id, token=self.token)
+ logger.debug(f"๋ฐ์ดํฐ์
์ด '{repo_id}'์ ์
๋ก๋๋์์ต๋๋ค.")
+
+
+if __name__ == "__main__":
+ os.chdir("..")
+ load_dotenv("../config/.env")
+ logger.debug(f'{os.getenv("UPLOAD_MODEL_NAME")}, {os.getenv("USERNAME")}, {os.getenv("CHECKPOINT_PATH")}')
+
+ hf_manager = HuggingFaceHubManager()
+ hf_manager.upload_model(os.getenv("UPLOAD_MODEL_NAME"), os.getenv("USERNAME"), os.getenv("CHECKPOINT_PATH"))
+ # hf_manager.upload_dataset(args.dataname, private=True)
diff --git a/config/elastic_setting.json b/config/elastic_setting.json
new file mode 100644
index 0000000..4e49153
--- /dev/null
+++ b/config/elastic_setting.json
@@ -0,0 +1,33 @@
+{
+ "settings": {
+ "analysis": {
+ "filter": {
+ "my_shingle": {
+ "type": "shingle"
+ }
+ },
+ "analyzer": {
+ "my_analyzer": {
+ "type": "custom",
+ "tokenizer": "nori_tokenizer",
+ "decompound_mode": "mixed",
+ "filter": ["my_shingle"]
+ }
+ },
+ "similarity": {
+ "my_similarity": {
+ "type": "BM25"
+ }
+ }
+ }
+ },
+
+ "mappings": {
+ "properties": {
+ "document_text": {
+ "type": "text",
+ "analyzer": "my_analyzer"
+ }
+ }
+ }
+}
diff --git a/config/sample/config.yaml b/config/sample/config.yaml
new file mode 100755
index 0000000..df99b8a
--- /dev/null
+++ b/config/sample/config.yaml
@@ -0,0 +1,95 @@
+data:
+ train_path: "../data/train.csv"
+ test_path: "../data/test.csv"
+ processed_train_path: "../data/train_500_60to1_es.csv" # ๋ฏธ๋ฆฌ ์ ์ฒ๋ฆฌํ ๋ฐ์ดํฐ ์ฌ์ฉ: ๋น์๋๋ฉด ๋์ํ์ง ์์
+ processed_test_path: "../data/test_500_60to1_es.csv" # ๋ฏธ๋ฆฌ ์ ์ฒ๋ฆฌํ ๋ฐ์ดํฐ ์ฌ์ฉ: ๋น์๋๋ฉด ๋์ํ์ง ์์
+ max_seq_length: 2048
+ test_size: 0.1
+ retriever:
+ retriever_type: "Elasticsearch" # Elasticsearch
+ query_type: "p" # retrieve ์ฟผ๋ฆฌ ํ์
: pqc, pq, pc, p
+ query_max_length: 500 # retrieve ๋์์ด ๋ ์ฟผ๋ฆฌ์ ์ต๋ ๊ธธ์ด: 250-500 ๊ถ์ฅ
+ result_max_length: 1500 # retrieve ๊ฒฐ๊ณผ ๋ฌธ์์ ์ต๋ ๊ธธ์ด: 1500-2000 ๊ถ์ฅ
+ top_k: 60 # 60~80
+ rerank_k: 1 # 0 ์ดํ๋ reranker ๋์ํ์ง ์์
+ threshold: 0.2 # 0.2 ~ 0.5
+ index_name: "two-wiki-index" # wiki-index, two-wiki-index, aihub-news-index
+ prompt:
+ start: "์ง๋ฌธ:\n {paragraph}\n\n์ง๋ฌธ:\n {question}\n\n์ ํ์ง:\n {choices}\n\n"
+ start_with_plus: "์ง๋ฌธ:\n {paragraph}\n\n์ง๋ฌธ:\n {question}\n\n<๋ณด๊ธฐ>:\n {question_plus}\n\n์ ํ์ง:\n {choices}\n\n"
+ mid: ""
+ mid_with_document: "ํํธ:\n {document}\n\n"
+ end: "1, 2, 3, 4, 5 ์ค์ ํ๋๋ฅผ ์ ๋ต์ผ๋ก ๊ณ ๋ฅด์ธ์.\n์ ๋ต:"
+ end_gen_cot: "1, 2, 3, 4, 5 ์ค์ ํ๋๋ฅผ ์ ๋ต์ผ๋ก ๊ณ ๋ฅด๊ธฐ ์ํ ๊ทผ๊ฑฐ๋ฅผ ์ฐจ๊ทผ์ฐจ๊ทผ ์๊ฐํด๋ณด์ธ์.\n๊ทผ๊ฑฐ:"
+ end_with_cot: "1, 2, 3, 4, 5 ์ค์ ํ๋๋ฅผ ์ ๋ต์ผ๋ก ๊ณ ๋ฅด์ธ์.\n{cot}\n์ ๋ต:"
+
+model:
+ base_model: "beomi/gemma-ko-2b"
+ model:
+ torch_dtype: "float16"
+ low_cpu_mem_usage: true
+ use_cache: false # gradient_checkpointing์ด true๋ฉด false์ฌ์ผํจ
+ quantization: "" # BitsAndBytes, auto
+ bits: 8 # 8 or 4
+ use_double_quant: false
+ tokenizer:
+ padding_side: "right"
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ 'user\n' + content + '\nmodel\n' }}{% elif message['role'] == 'assistant' %}{{ content + '\n' }}{% endif %}{% endfor %}"
+
+training:
+ response_template: "model"
+ lora:
+ r: 6
+ lora_alpha: 8
+ lora_dropout: 0.05
+ target_modules: ["q_proj", "k_proj"]
+ bias: "none"
+ task_type: "CAUSAL_LM"
+
+ params:
+ do_train: true
+ do_eval: true
+ lr_scheduler_type: "cosine"
+ max_seq_length: 2048
+ per_device_train_batch_size: 1
+ per_device_eval_batch_size: 1
+ gradient_accumulation_steps: 1
+ gradient_checkpointing: true
+ max_grad_norm: 0.3
+ num_train_epochs: 3
+ learning_rate: 2.0e-05
+ weight_decay: 0.01
+ optim: "adamw_torch" # ์์ํ: adamw_bnb_8bit
+ logging_strategy: "steps"
+ save_strategy: "steps"
+ eval_strategy: "steps"
+ logging_steps: 300
+ save_steps: 600
+ eval_steps: 300
+ save_total_limit: 4
+ save_only_model: true
+ load_best_model_at_end: true # early_stop์ ์ํด ํ์
+ report_to: "wandb"
+ run_name: "../outputs" # wandb ์ธํ
์ด ์กด์ฌํ๋ค๋ฉด ๋์ ์ผ๋ก ์์ฑ๋ฉ๋๋ค.
+ output_dir: "../outputs"
+ overwrite_output_dir: true
+ metric_for_best_model: "accuracy" # early_stop ๊ธฐ์ค
+ early_stop_patience: 2
+ early_stop_threshold: 0
+
+
+inference:
+ do_test: true
+ output_path: "../outputs/"
+
+log:
+ file: "../log/file.log"
+ level: "INFO"
+
+wandb:
+ project: generation_for_nlp
+ entity: hidong1015-nlp04
+
+exp:
+ # ์คํ์ [sujin, seongmin, sungjae, gayeon, yeseo, minseo]
+ username: fubao
diff --git a/config/sample/env-sample.txt b/config/sample/env-sample.txt
new file mode 100644
index 0000000..dc8689d
--- /dev/null
+++ b/config/sample/env-sample.txt
@@ -0,0 +1,17 @@
+# .env๋ก ๋ณํํ์ฌ ์ฌ์ฉ
+HF_TOKEN=""
+HF_TEAM_NAME = "paper-company"
+HF_PROJECT_NAME = "KSAT"
+
+GDRIVE_TOKEN="../config/token.json"
+GDRIVE_CREDENTIALS="../config/credentials.json"
+GDRIVE_FOLDER_ID =""
+GDRIVE_CREATE_TOKEN= "false"
+
+UPLOAD_MODEL_NAME = "1115-fubao-exaone3.0-base-v1" # ๋ ์ง-์ด๋ฆ-๋ฒ ์ด์ค๋ชจ๋ธ-์ฌ์ฉ๋ฐ์ดํฐ์
-๋ฒ์
+USERNAME = "fubao"
+CHECKPOINT_PATH="../outputs/checkpoint-9999"
+
+ELASTICSEARCH_URL="http://localhost:9200"
+ELASTICSEARCH_ID=""
+ELASTICSEARCH_PW=""
diff --git a/data_aug/add_CoT.py b/data_aug/add_CoT.py
new file mode 100644
index 0000000..417dd85
--- /dev/null
+++ b/data_aug/add_CoT.py
@@ -0,0 +1,57 @@
+from ast import literal_eval
+
+import dspy
+import pandas as pd
+from tqdm import tqdm
+
+
+def process_csv(file_path, output_path, lm_api_key):
+ lm = dspy.LM("openai/gpt-4o", api_key=lm_api_key)
+ dspy.configure(lm=lm)
+
+ df = pd.read_csv(file_path)
+
+ records = []
+ for _, row in df.iterrows():
+ problems = literal_eval(row["problems"])
+ record = {
+ "id": row["id"],
+ "paragraph": row["paragraph"],
+ "question": problems["question"],
+ "choices": problems["choices"],
+ "answer": problems.get("answer", None),
+ "question_plus": problems.get("question_plus", None),
+ }
+ records.append(record)
+
+ df = pd.DataFrame(records)
+
+ df["steps"] = None
+ data_list = []
+
+ for index, row in tqdm(df.iterrows(), total=len(df)):
+ input_data = {
+ "paragraph": row["paragraph"],
+ "question": row["question"],
+ "choices": row["choices"],
+ "answer": row["answer"],
+ }
+ classify = dspy.ChainOfThought("paragraph: str, question: str, choices: list, answer: str -> steps: list", n=1)
+
+ input_data["question"] = f"{row['question']} ๋จ๊ณ๋ณ ์ค๋ช
(CoT)์ ์ฌ์ฉํ์ฌ ์ฌ๋ฐ๋ฅธ ๋ต์ ๋์ถํ์ธ์."
+ response = classify(**input_data)
+ print("response.completions", response.completions)
+ data_list.append(response.completions)
+
+ df["steps"] = data_list
+
+ df.to_csv(output_path, index=False, encoding="utf-8-sig")
+ print(f"Updated CSV file saved to: {output_path}")
+
+
+if __name__ == "__main__":
+ input_file_path = ""
+ output_file_path = ""
+ api_key = ""
+
+ process_csv(input_file_path, output_file_path, api_key)
diff --git a/data_aug/aug_philo.py b/data_aug/aug_philo.py
new file mode 100644
index 0000000..67a210a
--- /dev/null
+++ b/data_aug/aug_philo.py
@@ -0,0 +1,64 @@
+import os
+import warnings
+
+from langchain.prompts import PromptTemplate
+from langchain_openai import ChatOpenAI
+import pandas as pd
+
+
+def parse_output(output):
+ lines = output.split("\n")
+ data_list = []
+ data = {"์ง๋ฌธ": "", "๋ฌธ์ ": "", "๋ณด๊ธฐ": "", "์ ๋ต": "", "ํด์ค": ""}
+ current_key = None
+
+ for line in lines:
+ line = line.strip()
+ if line.startswith("์ง๋ฌธ:"):
+ if data["์ง๋ฌธ"]:
+ data_list.append(data.copy())
+ data = {"์ง๋ฌธ": "", "๋ฌธ์ ": "", "๋ณด๊ธฐ": "", "์ ๋ต": "", "ํด์ค": ""}
+ current_key = "์ง๋ฌธ"
+ data["์ง๋ฌธ"] = line.replace("์ง๋ฌธ:", "").strip()
+ elif line.startswith("๋ฌธ์ :"):
+ current_key = "๋ฌธ์ "
+ data["๋ฌธ์ "] = line.replace("๋ฌธ์ :", "").strip()
+ elif line.startswith("๋ณด๊ธฐ:"):
+ current_key = "๋ณด๊ธฐ"
+ data["๋ณด๊ธฐ"] = line.replace("๋ณด๊ธฐ:", "").strip()
+ elif line.startswith("์ ๋ต:"):
+ current_key = "์ ๋ต"
+ data["์ ๋ต"] = line.replace("์ ๋ต:", "").strip()
+ elif line.startswith("ํด์ค:"):
+ current_key = "ํด์ค"
+ data["ํด์ค"] = line.replace("ํด์ค:", "").strip()
+ elif current_key:
+ data[current_key] += " " + line
+
+ if data["์ง๋ฌธ"]:
+ data_list.append(data)
+
+ return data_list
+
+
+if __name__ == "__main__":
+ warnings.filterwarnings("ignore")
+ os.environ["OPENAI_API_KEY"] = ""
+
+ # ์ฌ์ฉํ LLM ๋ชจ๋ธ ์ค์
+ llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0.9)
+
+ # ํ๋กฌํํธ ํ
ํ๋ฆฟ ์ค์
+ prompt_template = """"""
+
+ prompt = PromptTemplate(input_variables=[], template=prompt_template)
+
+ response = llm.invoke(prompt.format())
+ output = response.content
+
+ parsed_data_list = parse_output(output)
+
+ df = pd.DataFrame(parsed_data_list)
+
+ csv_filename = "philosophy_questions.csv"
+ df.to_csv(csv_filename, index=False, encoding="utf-8-sig")
diff --git a/data_process/crawling_gichulpass.py b/data_process/crawling_gichulpass.py
new file mode 100644
index 0000000..3b846b1
--- /dev/null
+++ b/data_process/crawling_gichulpass.py
@@ -0,0 +1,159 @@
+import json
+import re
+
+from bs4 import BeautifulSoup
+from datasets import load_dataset
+import pandas as pd
+import requests
+
+
+def answer_symbol_to_int(symbol: str) -> int:
+ # ์ ๊ทํํ์์ผ๋ก ํน์๋ฌธ์๋ง ์ถ์ถ
+ special_chars = re.findall(r"[โ โกโขโฃโค]", symbol)
+ cleaned_symbol = special_chars[0] if special_chars else symbol
+
+ answer_map = {"โ ": 1, "โก": 2, "โข": 3, "โฃ": 4, "โค": 5}
+ return answer_map.get(cleaned_symbol, -1)
+
+
+def extract_question_data(soup, with_table=False):
+ questions_with_table = []
+ questions_without_table = []
+
+ for item in soup.select("#examList > li"):
+ # ๋ฌธ์ ๋ฒํธ์ ์ง๋ฌธ
+ question_element = item.select_one(".pr_problem")
+ question_text = question_element.get_text(strip=True)
+
+ # ํ
์ด๋ธ ์กด์ฌ ์ฌ๋ถ ํ์ธ
+ has_table = False
+
+ # ๋ฌธ์ ์ ํ
์ด๋ธ์ด ์๋ ๊ฒฝ์ฐ
+ question_table = question_element.find("table")
+ if question_table:
+ has_table = True
+ question_text = {"text": question_text, "table_html": str(question_table)}
+
+ # ๋ฌธ์ ์ค๋ช
(์๋ ๊ฒฝ์ฐ)
+ example = item.select_one(".exampleCon")
+ example_text = example.get_text(strip=True) if example else ""
+
+ # ์์์ ํ
์ด๋ธ์ด ์๋ ๊ฒฝ์ฐ
+ if example:
+ example_table = example.find("table")
+ if example_table:
+ has_table = True
+ example_text = {"text": example_text, "table_html": str(example_table)}
+
+ # '๊ทธ๋ฆผ' ํฌํจ๋ ๋ฌธ์ ๋ ๊ฑด๋๋ฐ๊ธฐ
+ if isinstance(question_text, str) and "๊ทธ๋ฆผ" in question_text:
+ continue
+ if isinstance(example_text, str) and "๊ทธ๋ฆผ" in example_text:
+ continue
+
+ # ์ ํ์ง๋ฅผ ๋ฆฌ์คํธ๋ก ์ ์ฅ
+ choices = []
+ for choice in item.select(".questionCon li label"):
+ choices.append(choice.get_text(strip=True))
+
+ # ์ ๋ต ๋ฒํธ ์ถ์ถ
+ answer_element = item.select_one(".answer_num")
+ answer_text = answer_element.get_text(strip=True).strip() if answer_element else ""
+ answer = answer_symbol_to_int(answer_text)
+
+ # ์ ๋ต ์ค๋ช
+ explanation = item.select_one(".answer_explan")
+ explanation_text = explanation.get_text(strip=True) if explanation else ""
+
+ # ๋ฐ์ดํฐ ์ ์ฅ
+ question_data = {
+ "question": question_text,
+ "paragraph": example_text,
+ "choices": json.dumps(choices, ensure_ascii=False),
+ "answer": answer,
+ "answer_explanation": explanation_text,
+ }
+
+ # ํ
์ด๋ธ ์ ๋ฌด์ ๋ฐ๋ผ ๋ค๋ฅธ ๋ฆฌ์คํธ์ ์ ์ฅ
+ if has_table:
+ questions_with_table.append(question_data)
+ else:
+ questions_without_table.append(question_data)
+
+ return pd.DataFrame(questions_without_table) if not with_table else pd.DataFrame(questions_with_table)
+
+
+def crawl_and_save(subject_code):
+ headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
+
+ url = "https://gichulpass.com/bbs/board.php"
+
+ # ํ๊ณํ ๋ฌธ์
+ if subject_code == 20:
+ wr_ids = [903, 27, 889, 887, 885, 884, 880]
+
+ # ํ๋ฒ ๋ฌธ์
+ if subject_code == 26:
+ wr_ids = (
+ [1136, 1063, 963, 953, 875, 849, 543, 537, 526, 515, 510, 500, 542, 536, 525, 514, 509]
+ + [499, 541, 535, 524, 513, 508, 498, 540, 534, 523, 512, 507, 497, 539, 533, 522, 511, 506]
+ + [496, 538, 532, 521, 505, 495, 531, 520, 504, 494, 530, 519, 503, 493, 529, 518, 502, 492]
+ + [528, 517, 501, 491, 527, 516, 490]
+ )
+
+ # ํ๊ตญ์ฌ ๋ฌธ์
+ if subject_code == 34:
+ # 9๊ธ๋ง ํํฐ๋ง
+ wr_ids = (
+ list(range(808, 814))
+ + list(range(224, 243))
+ + list(range(264, 269))
+ + list(range(288, 298))
+ + list(range(305, 326))
+ + [16, 841, 870, 897, 897, 912, 962, 1012, 1026, 1053, 1167, 1172, 1173, 1176, 1177, 1184]
+ + [1205, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1257, 1262, 1357, 1367, 1368, 1404, 1405]
+ )
+
+ # ์ฌํ ๋ฌธ์
+ if subject_code == 35:
+ wr_ids = list(range(808, 840)) + [865, 894, 908, 1010, 1027, 1050]
+
+ dfs = []
+ for wr_id in wr_ids:
+ params = {"bo_table": "exam", "wr_id": wr_id, "subject": subject_code}
+ response = requests.get(url, params=params, headers=headers)
+ soup = BeautifulSoup(response.text, "html.parser")
+ df = extract_question_data(soup)
+ dfs.append(df)
+
+ concated_df = pd.concat(dfs, axis=0)
+ len_df = len(concated_df)
+ concated_df.to_csv(f"gichulpass_{subject_code}_{len_df}_raw.csv", index=False, encoding="utf-8")
+
+
+def check_KMMLU(input_file, output_file):
+ df = pd.read_csv(input_file, encoding="utf-8")
+ ds = load_dataset("HAERAE-HUB/KMMLU", "Korean-History")
+ # df์ None๊ฐ์ ๋น ๋ฌธ์์ด๋ก ๋ณํํ๊ณ ๋ชจ๋ ๊ฐ์ ๋ฌธ์์ด๋ก ๋ณํ
+ df = df.fillna("")
+ df = df.astype(str)
+
+ # ๋ชจ๋ split์ question์ ํ๋๋ก ํฉ์น๊ธฐ
+ questions = pd.concat([pd.DataFrame(ds["train"]), pd.DataFrame(ds["dev"]), pd.DataFrame(ds["test"])])
+
+ # ํฌํจ ์ฌ๋ถ๋ฅผ ํ์ธํ๋ ํจ์
+ def check_inclusion(column):
+ return column.apply(lambda x: any(x in question for question in questions["question"]))
+
+ # paragraph์ question ๊ฐ๊ฐ ํ์ธ ํ ๋ ์ค ํ๋๋ผ๋ True๋ฉด True
+ df["include"] = check_inclusion(df["paragraph"]) | check_inclusion(df["question"])
+
+ new_df = df[~df["include"]]
+ new_df.to_csv(output_file, index=False)
+
+
+if __name__ == "__main__":
+ crawl_and_save(subject_code=20)
+ crawl_and_save(subject_code=26)
+ crawl_and_save(subject_code=34)
+ crawl_and_save(subject_code=35)
diff --git a/data_process/external_musr.py b/data_process/external_musr.py
new file mode 100644
index 0000000..8167799
--- /dev/null
+++ b/data_process/external_musr.py
@@ -0,0 +1,63 @@
+from data_process.process_google_translate import TranslationCache, translate_column, translate_list_column
+from datasets import load_dataset
+from loguru import logger
+import pandas as pd
+from tqdm import tqdm
+
+
+def dataset_to_pd(data_name):
+ data = load_dataset(data_name)
+ dfs = [
+ pd.DataFrame(data["murder_mysteries"]),
+ pd.DataFrame(data["object_placements"]),
+ pd.DataFrame(data["team_allocation"]),
+ ]
+ return pd.concat(dfs, axis=0)
+
+
+def process_data(df):
+ return pd.DataFrame(
+ {
+ "paragraph": df["narrative"],
+ "question": df["question"],
+ "choices": df["choices"],
+ "answer": df["answer_index"].apply(lambda x: x + 1),
+ }
+ )
+
+
+def process_external_datasets(dataset_name, output_filename):
+ df = dataset_to_pd(dataset_name)
+ df = process_data(df)
+
+ # ๊ฐํ๋ฌธ์๋ฅผ \n ๋ฌธ์์ด๋ก ๋ณํ
+ for col in df.columns:
+ if df[col].dtype == "object": # ๋ฌธ์์ด ์ปฌ๋ผ์ ๋ํด์๋ง ์ฒ๋ฆฌ
+ df[col] = df[col].str.replace("\n", "\\n")
+
+ df.to_csv(output_filename, index=False)
+
+
+def translate_df(input_filename, output_filename):
+ df = pd.read_csv(input_filename)
+
+ logger.info("๋จ๋ฝ ๋ฒ์ญ ์ค...")
+ with TranslationCache("paragraph_cache.json") as paragraph_cache:
+ df["paragraph"] = translate_column(df["paragraph"], paragraph_cache)
+
+ logger.info("์ง๋ฌธ ๋ฒ์ญ ์ค...")
+ with TranslationCache("question_cache.json") as question_cache:
+ df["question"] = translate_column(df["question"], question_cache)
+
+ logger.info("์ ํ์ง ๋ฒ์ญ ์ค...")
+ with TranslationCache("choices_cache.json") as choices_cache:
+ tqdm.pandas(desc="์ ํ์ง ๋ฒ์ญ ์ค")
+ df["choices"] = df["choices"].apply(lambda x: translate_list_column(x, choices_cache))
+
+ df.to_csv(output_filename, index=False)
+
+
+if __name__ == "__main__":
+ dataset_name = "TAUR-Lab/MuSR"
+ process_external_datasets(dataset_name, "MuSR_en_raw.csv")
+ translate_df("MuSR_en_raw.csv", "MuSR_ko_raw.csv")
diff --git a/data_process/external_race.py b/data_process/external_race.py
new file mode 100644
index 0000000..cd7f01c
--- /dev/null
+++ b/data_process/external_race.py
@@ -0,0 +1,73 @@
+from datasets import load_dataset
+import pandas as pd
+
+
+def dataset_to_pd(data_name):
+ """์ฃผ์ด์ง ๋ฐ์ดํฐ์
์ด๋ฆ์ผ๋ก๋ถํฐ DataFrame์ ์์ฑํฉ๋๋ค."""
+ dataset = load_dataset(data_name, "high", split="validation") # 'train', 'validation', 'test' ์ค ์ ํ
+ return pd.DataFrame(dataset)
+
+
+def process_query_data(df):
+ """DataFrame์ ์ฒ๋ฆฌํ์ฌ ํ์ํ ํ์์ผ๋ก ๋ณํํฉ๋๋ค."""
+
+ # answer๋ฅผ A, B, C, D ํ์์์ 1, 2, 3, 4 ํ์์ผ๋ก ๋ณํ
+ def convert_answer(answer):
+ if answer == "A":
+ return 1
+ elif answer == "B":
+ return 2
+ elif answer == "C":
+ return 3
+ elif answer == "D":
+ return 4
+ else:
+ return None # ์์์น ๋ชปํ ๊ฐ์ ๋ํ ์ฒ๋ฆฌ
+
+ df["article"] = (
+ df["article"]
+ .str.replace(",", "")
+ .str.replace('""', "")
+ .str.replace(r"\. ", ".")
+ .str.replace(r'\." ', '."')
+ .str.replace(r"([.!?]) ", r"\1")
+ )
+ # ๋ฌธ์ ๋ฅผ ๋ฌธ์์ด๋ก ๋ณํํ์ฌ DataFrame ์์ฑ
+ problems = df.apply(
+ lambda row: {"question": row["question"], "choices": row["options"], "answer": convert_answer(row["answer"])},
+ axis=1,
+ )
+
+ return pd.DataFrame(
+ {
+ "id": df["example_id"], # ์์ ์ ID
+ "paragraph": df["article"], # ์ ๋ฆฌ๋ article ์ฌ์ฉ
+ "problems": problems.apply(str), # ๋ฌธ์ ๋ฅผ ๋ฌธ์์ด๋ก ๋ณํ
+ "question_plus": None, # question_plus๊ฐ ์๋ณธ ๋ฐ์ดํฐ์ ์๋ค๊ณ ๊ฐ์
+ }
+ )
+
+
+if __name__ == "__main__":
+ dataset_name = "ehovy/race" # ์ฌ์ฉํ ๋ฐ์ดํฐ์
์ด๋ฆ
+ df = dataset_to_pd(dataset_name) # ๋ฐ์ดํฐ์
์ DataFrame์ผ๋ก ๋ณํ
+ processed_df = process_query_data(df) # ๋ฐ์ดํฐ ์ฒ๋ฆฌ
+
+ # ๊ฒฐ๊ณผ๋ฅผ CSV ํ์ผ๋ก ์ ์ฅ (๋ฐ์ดํ ์ฒ๋ฆฌ)
+ processed_df.to_csv("processed_race_dataset.csv", index=False)
+
+ import re
+
+ import pandas as pd
+
+ # CSV ํ์ผ ์ฝ๊ธฐ
+ df = pd.read_csv("processed_race_dataset.csv")
+
+ # paragraph ์ด๋ง ์์
+ df["paragraph"] = df["paragraph"].apply(lambda x: re.sub(r"\n", "", x))
+ df["paragraph"] = df["paragraph"].apply(lambda x: re.sub(r"\. ", ".", x))
+ df["paragraph"] = df["paragraph"].apply(lambda x: re.sub(r'\." ', '."', x))
+ df["paragraph"] = df["paragraph"].apply(lambda x: re.sub(r"([.!?]) ", r"\1", x))
+
+ # ์์ ๋ DataFrame์ ๋ค์ CSV๋ก ์ ์ฅ
+ df.to_csv("modified_file.csv", index=False)
diff --git a/data_process/external_sat_gaokao.py b/data_process/external_sat_gaokao.py
new file mode 100644
index 0000000..904b285
--- /dev/null
+++ b/data_process/external_sat_gaokao.py
@@ -0,0 +1,86 @@
+from datasets import load_dataset
+import pandas as pd
+
+
+def dataset_to_pd(data_name):
+ data = load_dataset(data_name)
+ return pd.DataFrame(data["test"])
+
+
+def process_query_data(input_df):
+ def _split_query_data(df):
+ paragraphs = []
+ questions = []
+
+ for index, row in df.iterrows():
+ text = row["query"]
+ # Paragraph์ ๋๋จธ์ง ํ
์คํธ ๋ถ๋ฆฌ
+ paragraph, rest = text.split("Q:", 1)
+ # Question๊ณผ Answer Choices ๋ถ๋ฆฌ
+ question, choices = rest.split("Answer Choices: ", 1)
+
+ paragraphs.append(paragraph.strip())
+ questions.append(question.strip())
+
+ return pd.DataFrame(
+ {
+ "paragraph": paragraphs,
+ "question": questions,
+ "choices": df["choices"],
+ "answer": df["gold"].apply(lambda x: x[0] + 1), # gold ๋ฐฐ์ด์ ํ๊ณ +1
+ }
+ )
+
+ def _split_answer_choices(df):
+ def process_choices(choices_list): # ๋ฆฌ์คํธ ํํ๋ก ์
๋ ฅ ๋ฐ์
+ new_choices = []
+ for choice in choices_list:
+ new_choice = (
+ choice.replace("(A)", "")
+ .replace("(B)", "")
+ .replace("(C)", "")
+ .replace("(D)", "")
+ .replace("(E)", "")
+ )
+ new_choices.append(new_choice.strip())
+ return new_choices
+
+ df["choices"] = df["choices"].apply(process_choices)
+ return df
+
+ split_df = _split_query_data(input_df)
+ final_df = _split_answer_choices(split_df)
+ return final_df
+
+
+def process_and_concat_external_datasets(dataset_names, output_filename):
+ dfs = []
+ for dataset_name in dataset_names:
+ df = dataset_to_pd(dataset_name)
+ df = process_query_data(df)
+ dfs.append(df)
+
+ concated_df = pd.concat(dfs, axis=0)
+ concated_df.to_csv(output_filename, index=False)
+
+
+def clean_string(text):
+ # ๋ฌธ์์ด ๋ด๋ถ์ ๋ชจ๋ ํฐ๋ฐ์ดํ๋ฅผ ์์๋ฐ์ดํ๋ก ๋ณํ
+ text = text.replace('"', "'")
+ # ์ฐ์๋ ํฐ๋ฐ์ดํ๋ฅผ ํ๋๋ก ๋ณํ
+ while "''" in text:
+ text = text.replace("''", "'")
+ text = text.strip() # ์๋ค ๊ณต๋ฐฑ ์ ๊ฑฐ
+ return text
+
+
+if __name__ == "__main__":
+ dataset_names = [
+ "dmayhem93/agieval-sat-en",
+ "dmayhem93/agieval-logiqa-en",
+ "dmayhem93/agieval-lsat-rc",
+ "dmayhem93/agieval-lsat-lr",
+ "dmayhem93/agieval-lsat-ar",
+ "dmayhem93/agieval-gaokao-english",
+ ]
+ process_and_concat_external_datasets(dataset_names, "sat_gaokao_en_raw.csv")
diff --git a/data_process/pdf_to_txt.py b/data_process/pdf_to_txt.py
new file mode 100644
index 0000000..61889a7
--- /dev/null
+++ b/data_process/pdf_to_txt.py
@@ -0,0 +1,29 @@
+import os
+import re
+
+from pdfminer.high_level import extract_text
+
+
def split_text_by_keyword(text, keyword):
    """Split text at every occurrence of keyword, keeping the keyword attached
    to the end of each preceding section.

    The keyword is escaped with re.escape so regex metacharacters in it
    (e.g. '(', '.') are treated literally — the original rf-string pattern
    would misparse such keywords.
    """
    sections = re.split(re.escape(keyword), text)
    sections = [section.strip() + keyword for section in sections[:-1]] + [sections[-1].strip()]
    return sections
+
+
def save_sections_to_files(sections, output_dir="sections"):
    """Write each section string to its own numbered .txt file under output_dir."""
    os.makedirs(output_dir, exist_ok=True)
    for index, content in enumerate(sections, start=1):
        path = os.path.join(output_dir, f"section_{index}.txt")
        with open(path, "w", encoding="utf-8") as out_file:
            out_file.write(content)
    print(f"Sections saved to '{output_dir}' directory.")
+
+
if __name__ == "__main__":
    # Extract the exam PDF's text, split it into sections at the answer-directive
    # keyword, and write each section to its own file.
    pdf_file_path = "./data/test/2025.pdf"
    output_dir = "./data/test/sections"
    # Delimiter phrase that ends each question in the source PDF.
    keyword = "๋ตํ์์ค"

    text = extract_text(pdf_file_path)
    split_text = split_text_by_keyword(text, keyword)
    save_sections_to_files(split_text, output_dir=output_dir)
diff --git a/data_process/process_balance_choices.py b/data_process/process_balance_choices.py
new file mode 100644
index 0000000..4f1fa39
--- /dev/null
+++ b/data_process/process_balance_choices.py
@@ -0,0 +1,60 @@
+from ast import literal_eval
+import random
+
+import pandas as pd
+
+
def balance_choices_dataset(file_path):
    """Load a train CSV and randomly reshuffle each problem's choices,
    remapping the 1-based answer index to follow the correct choice."""
    df = pd.read_csv(file_path)
    # 'problems' is stored as a stringified dict; parse it safely.
    df["problems"] = df["problems"].apply(literal_eval)

    def reshuffle(problems):
        # Remember which choice is correct before shuffling.
        correct_choice = problems["choices"][problems["answer"] - 1]
        reordered = problems["choices"][:]
        random.shuffle(reordered)
        problems["choices"] = reordered
        problems["answer"] = reordered.index(correct_choice) + 1
        return problems

    for index, row in df.iterrows():
        df.at[index, "problems"] = reshuffle(row["problems"])
    return df
+
+
def answer_counts(file_path):
    """Print sanity-check stats for a dataset CSV: row count, distribution of
    choice-list lengths, and distribution of answer values."""
    df = pd.read_csv(file_path)
    df["problems"] = df["problems"].apply(literal_eval)

    rows = []
    for _, row in df.iterrows():
        problems = row["problems"]
        rows.append(
            {
                "id": row["id"],
                "paragraph": row["paragraph"],
                "question": problems["question"],
                "choices": problems["choices"],
                "answer": problems.get("answer", None),
                "question_plus": problems.get("question_plus", None),
            }
        )

    flat = pd.DataFrame(rows)
    print(len(flat))

    print(flat["choices"].apply(len).value_counts())
    print(flat["answer"].value_counts())
+
+
if __name__ == "__main__":
    # Shuffle choices to de-bias answer positions, save the result, then print
    # the resulting answer distribution as a sanity check.
    train_balanced = balance_choices_dataset("../data/train.csv")
    train_balanced.to_csv("../data/train_balanced.csv", index=False)
    answer_counts("../data/train_balanced.csv")
diff --git a/data_process/process_formatting.py b/data_process/process_formatting.py
new file mode 100644
index 0000000..4dd9f87
--- /dev/null
+++ b/data_process/process_formatting.py
@@ -0,0 +1,29 @@
+import pandas as pd
+
+
def formatting(suffix, input_filename, output_filename):
    """Convert a raw external-data CSV into the project schema and save it.

    suffix: prefix fragment used to build each row id ("external-data-<suffix><n>").
    input_filename: CSV with 'paragraph', 'question', 'choices' (stringified list),
        and 'answer' columns.
    output_filename: destination CSV with 'id', 'paragraph', 'problems' columns.
    """
    # literal_eval only accepts Python literals, unlike the original eval(),
    # which would execute arbitrary expressions from the CSV.
    from ast import literal_eval

    df = pd.read_csv(input_filename, encoding="utf-8")

    new_records = []
    for idx, row in df.iterrows():
        new_record = {
            "id": f"external-data-{suffix}{idx + 1}",
            # NaN paragraphs become empty strings.
            "paragraph": "" if pd.isna(row["paragraph"]) else str(row["paragraph"]),
            "problems": {
                "question": row["question"].strip(),
                "choices": literal_eval(row["choices"]),
                "answer": row["answer"],
            },
        }
        new_records.append(new_record)

    new_df = pd.DataFrame(new_records)
    new_df.to_csv(output_filename, index=False)
+
+
if __name__ == "__main__":
    # Convert each raw external dataset into the project schema.
    formatting("gichulpass20-", "gichulpass_20_107_raw.csv", "gichulpass_20_107.csv")
    # Fixed: the output previously wrote to gichulpass_24_1319.csv, mismatching
    # the "26" input file and suffix.
    formatting("gichulpass26-", "gichulpass_26_1319_raw.csv", "gichulpass_26_1319.csv")
    formatting("gichulpass34-", "gichulpass_34_1352_raw.csv", "gichulpass_34_1352.csv")
    formatting("gichulpass35-", "gichulpass_35_568_raw.csv", "gichulpass_35_568.csv")
    formatting("SAT", "sat_gaokao_ko_raw.csv", "sat_gaokao_ko.csv")
    formatting("MuSR", "MuSR_ko_raw.csv", "MuSR_ko.csv")
diff --git a/data_process/process_google_translate.py b/data_process/process_google_translate.py
new file mode 100644
index 0000000..cf62868
--- /dev/null
+++ b/data_process/process_google_translate.py
@@ -0,0 +1,80 @@
+from ast import literal_eval
+import json
+import os
+import time
+
+from googletrans import Translator
+from loguru import logger
+from tqdm import tqdm
+
+
class TranslationCache:
    """Context manager wrapping a JSON-file-backed translation cache.

    On enter, loads any persisted translations; on exit, writes the cache
    back to disk and frees the in-memory copy.
    """

    def __init__(self, cache_file="translation_cache.json"):
        self.cache_file = cache_file
        self.cache = {}

    def __enter__(self):
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                self.cache = json.load(f)
        except FileNotFoundError:
            # First run: start with an empty cache.
            self.cache = {}
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        with open(self.cache_file, "w", encoding="utf-8") as f:
            json.dump(self.cache, f, ensure_ascii=False, indent=2)
        self.cache.clear()  # release memory

    def get_translation(self, text):
        """Return the cached translation for text, or None when absent."""
        return self.cache.get(text)

    def add_translation(self, text, translation):
        """Record a translation in the in-memory cache."""
        self.cache[text] = translation
+
+
def translate_with_retry(translator, text, cache, max_retries=3):
    """Translate text (en -> ko) with caching and up to max_retries attempts.

    Returns the cached translation when available; on repeated failure, logs
    the error and falls back to returning the original text untranslated.
    """
    # Compare against None so that an empty string cached as a valid
    # translation is honored (the original truthiness check treated "" as a
    # cache miss and retranslated it).
    cached_translation = cache.get_translation(text)
    if cached_translation is not None:
        return cached_translation

    for attempt in range(max_retries):
        try:
            translated = translator.translate(text, src="en", dest="ko").text
            time.sleep(0.01)  # brief pause between API calls
            cache.add_translation(text, translated)
            return translated
        except Exception as e:
            if attempt == max_retries - 1:
                logger.info(f"๋ฒ์ญ ์คํจ: {str(e)}")
                return text  # give up: fall back to the source text
            time.sleep(1)  # back off before retrying
+
+
def translate_list_column(text, cache):
    """Translate every element of a stringified list (e.g. "['a', 'b']"),
    returning the list of translated items."""
    items = literal_eval(text)
    translator = Translator()
    return [translate_with_retry(translator, item, cache) for item in items]
+
+
def translate_column(texts, cache):
    """Translate a sequence of texts, deduplicating to avoid repeat API calls.

    Returns translations in the original input order. Entries whose
    translation failed (and were therefore never cached) fall back to the
    source text instead of None.
    """
    translator = Translator()

    # Translate each unique text once; translate_with_retry stores results in
    # the cache as a side effect (its return value is not needed here).
    unique_texts = list(set(texts))
    for unique_text in tqdm(unique_texts, desc="๊ณ ์ ํ์คํธ ๋ฒ์ญ ์ค"):
        translate_with_retry(translator, unique_text, cache)

    # Rebuild the full column from the cache, preserving the original order.
    translated_texts = []
    for text in texts:
        translated = cache.get_translation(text)
        # Failed translations are not cached; keep the source text rather
        # than emitting None (the original appended None in that case).
        translated_texts.append(text if translated is None else translated)

    return translated_texts
diff --git a/data_viz/csv2pdf.py b/data_viz/csv2pdf.py
new file mode 100644
index 0000000..c08e0fb
--- /dev/null
+++ b/data_viz/csv2pdf.py
@@ -0,0 +1,132 @@
+import argparse
+import ast
+import os
+import sys
+import textwrap
+
+import pandas as pd
+from reportlab.lib.pagesizes import A4
+from reportlab.lib.units import cm
+from reportlab.pdfbase import pdfmetrics
+from reportlab.pdfbase.ttfonts import TTFont
+from reportlab.pdfgen import canvas
+
+
def draw_wrapped_text(c, text, x, y, max_width, max_height, line_height=14):
    """Draw text wrapped to a fixed character width, starting a new page when
    the text cursor drops below the max_height threshold.

    c: reportlab canvas. x, y: starting position in points.
    max_height: Y-coordinate threshold (in points) below which a page break is
        triggered — despite the name, it is compared as a position, not a height.
    line_height: leading (line spacing) in points.

    NOTE(review): max_width is unused — wrapping is hard-coded to 70 characters
    via textwrap.fill below; confirm whether width=max_width was intended.
    """
    wrapped_text = textwrap.fill(text, width=70)
    text_obj = c.beginText(x, y)
    text_obj.setFont("NanumGothic", 10)
    text_obj.setLeading(line_height)

    for line in wrapped_text.splitlines():
        # Once the cursor passes the threshold, flush the accumulated text,
        # start a new page, and continue near the top of the fresh page.
        if text_obj.getY() < max_height:
            c.drawText(text_obj)
            c.showPage()
            text_obj = c.beginText(x, A4[1] - 3 * cm)
            text_obj.setFont("NanumGothic", 10)
            text_obj.setLeading(line_height)
        text_obj.textLine(line)

    c.drawText(text_obj)
+
+
def create_csat_style_pdf(data, filename):
    """Render each dataset row (id, paragraph, question, choices, answer) as a
    CSAT-style page in a single PDF.

    data: DataFrame with 'id', 'paragraph', 'problems' columns, where
        'problems' is a stringified dict holding 'question', 'choices', 'answer'.
    filename: output PDF path. Requires the "NanumGothic" font to be registered.
    """
    c = canvas.Canvas(filename, pagesize=A4)
    width, height = A4

    for index, row in data.iterrows():
        try:
            # Problem ID header.
            c.setFont("NanumGothic", 12)
            c.drawString(1 * cm, height - 1 * cm, f"๋ฌธ์ ID: {row['id']}")

            # Passage body.
            paragraph = row["paragraph"]
            draw_wrapped_text(c, paragraph, 1 * cm, height - 2 * cm, max_width=85, max_height=14 * cm)

            # Question text; 'problems' is stored as a string, parse it safely.
            problem_data = ast.literal_eval(row["problems"])
            question = problem_data["question"]
            draw_wrapped_text(
                c,
                f"๋ฌธ์ : {question}",
                1 * cm,
                height - 16 * cm,
                max_width=85,
                max_height=8 * cm,
            )

            # Numbered answer choices, stepping down the page per choice.
            choices = problem_data["choices"]
            choice_y = height - 19 * cm
            for i, choice in enumerate(choices, 1):
                draw_wrapped_text(
                    c,
                    f"{i}. {choice}",
                    1 * cm,
                    choice_y,
                    max_width=85,
                    max_height=3 * cm,
                )
                choice_y -= 1.2 * cm

            # Correct answer line below the choices.
            answer = problem_data["answer"]
            draw_wrapped_text(
                c,
                f"์ ๋ต: {answer}",
                1 * cm,
                choice_y - 1 * cm,
                max_width=85,
                max_height=3 * cm,
            )

            # One page per problem.
            c.showPage()
        except KeyError as ke:
            # A row missing a required key is reported and skipped.
            print(f"Data format error: ํ์ ํค๊ฐ ์์ต๋๋ค. {ke}")

    # Persist the PDF.
    c.save()
+
+
if __name__ == "__main__":
    # CLI: path to the input CSV.
    parser = argparse.ArgumentParser(description="Generate a CSAT style PDF from CSV data.")
    parser.add_argument(
        "--csv_path",
        default="../data/train.csv",
        help="Path to the CSV file containing the data.",
    )
    args = parser.parse_args()

    # Read the CSV and validate that the required columns are present.
    try:
        df = pd.read_csv(args.csv_path)
        required_columns = {"id", "paragraph", "problems"}
        if not required_columns.issubset(df.columns):
            missing_columns = required_columns - set(df.columns)
            raise ValueError(f"CSV ํ์ผ์ ํ์ํ ์ปฌ๋ผ์ด ์์ต๋๋ค: {', '.join(missing_columns)}")
    except FileNotFoundError:
        print(f"CSV ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {args.csv_path}")
        sys.exit(1)
    except ValueError as ve:
        print(ve)
        sys.exit(1)

    # Register the Korean font required for rendering.
    font_path = os.path.abspath("../data/NanumGothic.ttf")
    if not os.path.isfile(font_path):
        # NOTE(review): these two adjacent f-strings concatenate with no
        # separator, so the download URL runs directly into the path —
        # consider adding a space or newline between them.
        print(
            f"ํฐํธ ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {font_path}"
            f"https://hangeul.naver.com/fonts/search?f=nanum ์์ ํฐํธ๋ฅผ ๋ค์ด๋ฐ์ dataํด๋์ ๋ฃ์ด์ฃผ์ธ์."
        )
        sys.exit(1)

    pdfmetrics.registerFont(TTFont("NanumGothic", font_path))

    # Output PDF path: same location/name as the input CSV, extension swapped to .pdf.
    pdf_filename = os.path.splitext(args.csv_path)[0] + ".pdf"

    # Generate the PDF.
    create_csat_style_pdf(df, pdf_filename)
diff --git a/data_viz/labeling.py b/data_viz/labeling.py
new file mode 100644
index 0000000..9c60556
--- /dev/null
+++ b/data_viz/labeling.py
@@ -0,0 +1,98 @@
+from ast import literal_eval
+
+import pandas as pd
+import streamlit as st
+
+
def load_data(file_path):
    """Read a labeled CSV and flatten each row's stringified 'problems' dict.

    Returns (raw DataFrame, list of flat record dicts).
    """
    data = pd.read_csv(file_path)
    # Keys that may or may not be present in a given 'problems' dict.
    optional_keys = ("answer", "question_plus", "target", "suggested_label", "is_label_issue")

    records = []
    for _, row in data.iterrows():
        problems = literal_eval(row["problems"])
        record = {
            "id": row["id"],
            "paragraph": row["paragraph"],
            "question": problems["question"],
            "choices": problems["choices"],
        }
        for key in optional_keys:
            record[key] = problems.get(key, None)
        records.append(record)

    return data, records
+
+
def display_instance(record):
    """Render a single flattened record (see load_data) on the Streamlit page.

    The call order defines the page layout: paragraph, question, numbered
    choices, answer, then the optional question_plus supplement.
    """
    st.subheader("Paragraph")
    st.write(record["paragraph"])

    st.subheader("Question, Choices, Answer")
    st.markdown("#### Question:")
    st.write(record["question"])

    st.markdown("#### Choices:")
    # Choices are displayed 1-indexed to match the stored 1-based answer.
    for i, choice in enumerate(record["choices"], 1):
        st.write(f"{i} : {choice}")

    st.markdown("#### Answer:")
    st.write(str(record["answer"]))

    st.subheader("Question Plus")
    st.write(str(record["question_plus"]))
+
+
def main():
    """Streamlit page for inspecting suggested-label statistics and instances."""
    st.title("CSV ๋ฐ์ดํฐ ์ธ์คํด์ค ๋ทฐ์ด")

    data, records = load_data("../data/cleaned_output_with_labels_CL.csv")

    # Split the data at a fixed row index.
    # NOTE(review): the original comment said index 1055 and the display labels
    # below say 1380, but the actual split index is 792 — reconcile these.
    split_index = 792
    before_split = data.iloc[:split_index]
    after_split = data.iloc[split_index:]

    # Count suggested_label values on each side of the split.
    label_counts = {
        "Before 1380": {
            "Label 0": (before_split["suggested_label"] == 0).sum(),
            "Not Label 0": (before_split["suggested_label"] != 0).sum(),
        },
        "After 1380": {
            "Label 1": (after_split["suggested_label"] == 1).sum(),
            "Not Label 1": (after_split["suggested_label"] != 1).sum(),
        },
    }

    # Show the counts as a table.
    label_counts_df = pd.DataFrame(label_counts)

    st.subheader("Suggested Label ๊ฐ์")
    st.write(label_counts_df)

    # Rows before the split whose suggested_label is 1 (unexpected side).
    before_split_label_1 = before_split[before_split["suggested_label"] == 1]

    # Rows after the split whose suggested_label is 0 (unexpected side).
    # NOTE(review): variable is named *_label_1 but filters on == 0.
    after_split_label_1 = after_split[after_split["suggested_label"] == 0]

    # Show the filtered rows.
    st.subheader("Before 1380์์ suggested_label์ด 1์ธ ์ธ์คํด์ค")
    st.write(before_split_label_1)

    st.subheader("After 1380์์ suggested_label์ด 0์ธ ์ธ์คํด์ค")
    st.write(after_split_label_1)

    # Instance picker bounded by the dataset size.
    instance_index = st.number_input("์ธ์คํด์ค ์ ํ", min_value=0, max_value=len(data) - 1, value=0, step=1)

    st.write(f"์ ํ๋ ์ธ์คํด์ค (์ธ๋ฑ์ค {instance_index}):")
    st.write(data.iloc[instance_index])

    display_instance(records[instance_index])
+
+
if __name__ == "__main__":
    # Run the Streamlit labeling viewer.
    main()
diff --git a/data_viz/streamlit_app.py b/data_viz/streamlit_app.py
new file mode 100644
index 0000000..51153d9
--- /dev/null
+++ b/data_viz/streamlit_app.py
@@ -0,0 +1,71 @@
+from ast import literal_eval
+import re
+
+import pandas as pd
+import streamlit as st
+
+
def load_data(file_path):
    """Read a CSV, flattening each 'problems' dict and attaching the optional
    'documents' column.

    Returns (raw DataFrame, list of flat record dicts).
    """
    data = pd.read_csv(file_path)

    records = []
    for _, row in data.iterrows():
        problems = literal_eval(row["problems"])
        records.append(
            {
                "id": row["id"],
                "paragraph": row["paragraph"],
                "question": problems["question"],
                "choices": problems["choices"],
                "answer": problems.get("answer", None),
                "question_plus": problems.get("question_plus", None),
                # Series.get -> None when the CSV has no 'documents' column.
                "documents": row.get("documents", None),
            }
        )

    return data, records
+
+
def display_instance(left, right, record):
    """Render one record across a two-column Streamlit layout.

    left/right: st.columns handles. The left column shows the instance
    (paragraph, question, choices, answer); the right column shows the
    retrieved documents split at each '[' marker.
    """
    with left:
        st.subheader("Paragraph")
        st.write(record["paragraph"])

        st.subheader("Question, Choices, Answer")
        st.markdown("#### Question:")
        st.write(record["question"])

        st.markdown("#### Choices:")
        for i, choice in enumerate(record["choices"], 1):
            # NOTE(review): this branch appears meant to highlight the correct
            # choice, but it applies no styling — both branches currently
            # render the same plain text.
            if i == record["answer"]:
                st.markdown(f"{i} : {choice}", unsafe_allow_html=True)
            else:
                st.write(f"{i} : {choice}")

        st.markdown("#### Answer:")
        st.write(str(record["answer"]))

        st.subheader("Question Plus")
        st.write(str(record["question_plus"]))
    with right:
        st.subheader("Documents")
        # Split before each '[' so each retrieved chunk renders as its own block.
        result = re.split(r"(?=\[)", str(record["documents"]))
        for part in result:
            st.write(part)
+
+
def main(file_path="../data/train.csv"):
    """Two-column Streamlit viewer: instance picker and details on the left,
    retrieved documents on the right."""
    st.set_page_config(layout="wide")
    st.title("CSV ๋ฐ์ดํฐ ์ธ์คํด์ค ๋ทฐ์ด")
    data, records = load_data(file_path)

    left, right = st.columns([0.5, 0.5])
    with left:
        # Picker bounded by the dataset size.
        instance_index = st.number_input("์ธ์คํด์ค ์ ํ", min_value=0, max_value=len(data) - 1, value=0, step=1)

    display_instance(left, right, records[instance_index])

    with left:
        st.write(f"์ ํ๋ ์ธ์คํด์ค (์ธ๋ฑ์ค {instance_index}):")
        st.write(data.iloc[instance_index])
+
+
if __name__ == "__main__":
    # Launch the viewer against the retrieval-augmented training set.
    main("../data/train_retrieve.csv")
diff --git a/ensemble/hard_voting.py b/ensemble/hard_voting.py
new file mode 100644
index 0000000..8965761
--- /dev/null
+++ b/ensemble/hard_voting.py
@@ -0,0 +1,90 @@
+from collections import Counter
+import csv
+from glob import glob
+
+
+"""
+# ํ๋ ๋ณดํ
์์๋ธ ์ฌ์ฉ ๋ฐฉ๋ฒ (CSV ๋ฒ์ )
+
+1. ํ์ผ ์ค๋น:
+ - 'ensemble/results_hard' ํด๋ ์์ ์์๋ธํ๊ณ ์ถ์ ๋ชจ๋ CSV ํ์ผ๋ค์ ๋ฃ์ต๋๋ค.
+
+2. ์ฐ์ ์์ ์ค์ :
+ - 'priority_order' ๋ฆฌ์คํธ์ ๋ชจ๋ธ์ ์ฐ์ ์์๋ฅผ ์ ์ํฉ๋๋ค.
+ - ์: priority_order = ['predictions1.csv', 'predictions2.csv', 'predictions3.csv']
+ - ๋ฆฌ์คํธ์ ์์ชฝ์ ์๋ ๋ชจ๋ธ์ผ์๋ก ๋์ ์ฐ์ ์์๋ฅผ ๊ฐ์ง๋๋ค.
+
+3. ์ฝ๋ ์คํ:
+ - ์ค์ ์ ๋ง์น ํ ์ฝ๋๋ฅผ ์คํํฉ๋๋ค.
+ - ์ฝ๋๋ ์๋์ผ๋ก ํด๋ ๋ด์ ๋ชจ๋ CSV ํ์ผ์ ์ฝ์ด ์์๋ธ์ ์ํํฉ๋๋ค.
+
+4. ํ๋ ๋ณดํ
๊ณผ์ :
+ - ๊ฐ ์ง๋ฌธ์ ๋ํด ๋ชจ๋ ๋ชจ๋ธ์ ๋ต๋ณ์ ์์งํฉ๋๋ค.
+ - ๊ฐ์ฅ ๋ง์ด ๋์จ ๋ต๋ณ(๋ค)์ ์ ํํฉ๋๋ค.
+ - ๋์ ์ธ ๊ฒฝ์ฐ, ์ฐ์ ์์๊ฐ ๊ฐ์ฅ ๋์ ๋ชจ๋ธ์ ๋ต๋ณ์ ์ ํํฉ๋๋ค.
+
+5. ๊ฒฐ๊ณผ ํ์ธ:
+ - ์์๋ธ ๊ฒฐ๊ณผ๋ 'final_hard_predictions.csv' ํ์ผ๋ก ์ ์ฅ๋ฉ๋๋ค.
+ - ์ด ํ์ผ์๋ ๊ฐ ์ง๋ฌธ์ ๋ํ ์ต์ข
๋ต๋ณ์ด ํฌํจ๋์ด ์์ต๋๋ค.
+
+์ฃผ์: ๋ชจ๋ธ์ ์ฐ์ ์์๋ ๊ฐ ๋ชจ๋ธ์ ์ฑ๋ฅ์ด๋ ํน์ฑ์ ๊ณ ๋ คํ์ฌ ์ ์คํ ๊ฒฐ์ ํด์ผ ํฉ๋๋ค.
+์ฐ์ ์์ ์ค์ ์ ๋ฐ๋ผ ์ต์ข
๊ฒฐ๊ณผ๊ฐ ํฌ๊ฒ ๋ฌ๋ผ์ง ์ ์์ต๋๋ค.
+"""
+
+# ์ฐ์ ์์ ์ง์ ์ ์
+priority_order = ["output (1).csv", "output (7).csv", "output (8).csv"]
+
+
def hard_voting_with_priority(predictions, priority_order):
    """Majority-vote answers across prediction dicts; break ties by priority.

    predictions: list of {question_id: answer} dicts; each dict may also carry
        a special "filename" entry identifying its source model, used only for
        priority matching (and excluded from voting).
    priority_order: filenames ordered from highest to lowest priority.

    Empty answers are ignored during counting; an id with no non-empty answer
    maps to "".
    """
    result = {}
    for id in predictions[0].keys():
        if id == "filename":
            continue  # metadata, not a question id — must not be voted on
        answers = [pred[id] for pred in predictions if id in pred]
        answer_counts = Counter(answer for answer in answers if answer)

        if not answer_counts:
            result[id] = ""
            continue

        max_count = max(answer_counts.values())
        top_answers = [ans for ans, count in answer_counts.items() if count == max_count]

        if len(top_answers) == 1:
            result[id] = top_answers[0]
            continue

        # Tie: walk models in priority order and take the first whose answer
        # is among the tied leaders.
        # NOTE(review): the loading script never stores a "filename" key, so
        # this lookup currently always misses and the for-else fallback fires —
        # store the file's basename in each prediction if priority matters.
        for model in priority_order:
            model_index = next((i for i, pred in enumerate(predictions) if pred.get("filename") == model), None)
            if model_index is not None:
                model_answer = predictions[model_index].get(id, "")
                if model_answer in top_answers:
                    result[id] = model_answer
                    break
        else:
            result[id] = top_answers[0]
    return result
+
+
# Load every prediction CSV to ensemble.
prediction_files = glob("./results_hard/*.csv")
predictions = []

# Read each prediction file into an {id: answer} dict.
# NOTE(review): no "filename" key is ever stored here, so the priority-based
# tie-break in hard_voting_with_priority can never match a model and always
# falls back to the first top answer — store the file's basename in each
# prediction dict if priority_order is meant to have an effect.
for file_name in prediction_files:
    prediction = {}
    with open(file_name, "r", encoding="utf-8") as file:
        csv_reader = csv.reader(file)
        next(csv_reader)  # skip the header row
        for row in csv_reader:
            prediction[row[0]] = row[1]  # first column is the id, second the answer
    predictions.append(prediction)

# Run hard voting across all loaded predictions.
final_predictions = hard_voting_with_priority(predictions, priority_order)

# Write the ensembled answers to CSV.
with open("final_hard_predictions.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["id", "answer"])  # header
    for id, answer in final_predictions.items():
        writer.writerow([id, answer])

print("์์๋ธ ๊ฒฐ๊ณผ๊ฐ 'final_hard_predictions.csv' ํ์ผ๋ก ์ ์ฅ๋์์ต๋๋ค.")
diff --git a/.github/.keep b/ensemble/results_hard/.gitkeep
similarity index 100%
rename from .github/.keep
rename to ensemble/results_hard/.gitkeep
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..db9b2f8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,48 @@
+[tool.ruff]
+line-length = 120
+
+# Exclude the following files and directories.
+exclude = [
+ ".git",
+ ".hg",
+ ".mypy_cache",
+ ".tox",
+ ".venv",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "env",
+ "venv",
+ "**/*.ipynb", # Jupyter Notebook ํ์ผ ์ ์ธ
+]
+
+[tool.ruff.lint]
+# Additionally enforce complexity (C901), line length (E501), and import placement (E402).
+extend-select = ["C901", "E501", "E402"]
+select = ["C", "E", "F", "I", "W"]
+
+
+[tool.ruff.lint.per-file-ignores]
+"__init__.py" = ["E402", "F401", "F403", "F811"]
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+
+# Setting the order of sections
+section-order = ["standard-library", "third-party", "local-folder"]
+combine-as-imports = true
+force-sort-within-sections = true
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..49cc160
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,51 @@
+# CUDA Version: 12.2
+# Ubuntu 20.04.6
+# python 3.10.13
+
+# Deep Learning
+auto_gptq==0.7.1
+bitsandbytes==0.44.1
+evaluate==0.4.3
+huggingface-hub==0.26.2
+numpy==2.0.0
+optimum==1.23.3
+peft==0.5.0
+scikit-learn==1.5.2
+torch==2.5.1 # 2.5.1+cu124
+tqdm==4.67.0
+transformers==4.46.2
+trl==0.12.0
+wandb==0.18.5
+
+# RAG
+elasticsearch==8.16.0
+konlpy==0.6.0
+rank-bm25==0.2.2
+wikiextractor==3.0.6
+faiss-cpu==1.9.0 # faiss-gpu==1.7.2
+
+# Utils
+beautifulsoup4==4.12.3
+ipykernel==6.29.5
+ipywidgets==8.1.5
+loguru==0.7.2
+matplotlib==3.9.2
+python-dotenv==1.0.1
+reportlab==4.2.5
+streamlit==1.40.1
+pdfminer.six==20240706
+
+# Google Drive API
+google-api-python-client==2.151.0
+google-auth-httplib2==0.2.0
+google-auth-oauthlib==1.2.1
+
+# Automatically installed dependencies
+# pandas==2.2.3
+# pyarrow==18.0.0
+# datasets==3.1.0
+# safetensors==0.4.5
+# scipy==1.14.1
+# tqdm==4.67.0
+# PyYAML==6.0.2
+# requests==2.32.3
diff --git a/script/run_script.bash b/script/run_script.bash
new file mode 100755
index 0000000..c564b27
--- /dev/null
+++ b/script/run_script.bash
@@ -0,0 +1,11 @@
#!/bin/bash

# Run the two training configurations back to back; `wait` blocks until the
# backgrounded run finishes, so the second starts only after the first ends.
for config in config-normal.yaml config-rag.yaml; do
    nohup python -u main.py --config "$config" &
    wait $!
done
diff --git a/script/run_with_gpu_monitoring.bash b/script/run_with_gpu_monitoring.bash
new file mode 100755
index 0000000..fecac05
--- /dev/null
+++ b/script/run_with_gpu_monitoring.bash
@@ -0,0 +1,31 @@
#!/bin/bash

# For each optimizer config: start a once-per-second GPU-stats logger, run the
# training job to completion, then stop the logger. Replaces three copy-pasted
# stanzas with a single loop (same commands, same order).
for optimizer in adamw_torch adafactor adamw_bnb_8bit; do
    nvidia-smi --query-gpu=timestamp,name,utilization.gpu,memory.used,memory.free \
        --format=csv -l 1 > "../log/gpu_log_${optimizer}.csv" &
    monitor_pid=$!

    nohup python -u main.py --config "config-${optimizer}.yaml" &
    wait $!

    # Stop this run's monitoring process.
    kill $monitor_pid
done
diff --git a/script/setup-gpu-server.bash b/script/setup-gpu-server.bash
new file mode 100755
index 0000000..e73dcb0
--- /dev/null
+++ b/script/setup-gpu-server.bash
@@ -0,0 +1,111 @@
#!/bin/bash

##########################################
# Development-environment setup to run when creating a GPU server instance.
# Includes a conda bootstrap for environments where conda is not yet set up.
# Adjust usernames / directories / permissions before use.
##########################################

##################### Install #####################
apt-get update
apt-get install -y sudo
sudo apt-get install -y wget git vim build-essential

##################### Set root password #####################
# NOTE(review): hardcoded root password — acceptable only on throwaway,
# non-reachable instances.
echo "root:root" | chpasswd

##################### conda #####################
# Put the preinstalled conda on PATH, disable auto-activation of base, and
# create the shared "main" environment.
export PATH="/opt/conda/bin:$PATH"
conda init bash
conda config --set auto_activate_base false
source ~/.bashrc
conda create -n main python=3.10.13 -y
# NOTE(review): conda environments normally live under /opt/conda/envs
# (plural) — confirm this path; 777 is also broader than needed.
sudo chmod -R 777 /opt/conda/env
+
##################### Users: dir & permission #####################
# Accounts to create, each with a home under /data/ephemeral/home.
users=("camper")

for i in "${!users[@]}"; do
    user="${users[$i]}"
    user_folder="/data/ephemeral/home/$user"

    # Create user with custom home directory and give sudo privileges
    sudo mkdir -p $user_folder
    sudo chmod 777 $user_folder
    sudo adduser --disabled-password --home $user_folder --gecos "" $user
    # Set user password same as username
    echo "${user}:${user}" | sudo chpasswd
    sudo chsh -s /bin/bash $user
    echo "$user ALL=(ALL) NOPASSWD:ALL" | sudo tee /etc/sudoers.d/$user

done

##################### Users: conda #####################
# Wire each user's shell up to conda and auto-activate the "main" env.
for user in "${users[@]}"; do
    user_folder="/data/ephemeral/home/$user"

    # Add conda to each user's PATH and initialize conda
    su - $user bash -c 'export PATH="/opt/conda/bin:$PATH"; conda init bash; conda config --set auto_activate_base false; source ~/.bashrc;'
    echo "cd $user_folder" | sudo tee -a $user_folder/.bashrc
    echo 'conda activate main' | sudo tee -a $user_folder/.bashrc

    # Add local bin path to each user's .bashrc
    echo "export PATH=\$PATH:/data/ephemeral/home/$user/.local/bin" | sudo tee -a $user_folder/.bashrc

    sudo chmod -R 777 $user_folder
    sudo chown -R $user:$user $user_folder

done
+
##################### Git #####################
# Per-member working directories under the shared camper home; commits made in
# a member's folder pick up that member's git identity via includeIf.
users=("sujin" "seongmin" "sungjae" "gayeon" "yeseo" "minseo")
BASE_DIR="/data/ephemeral/home/camper"

# Create a working directory per team member.
for user in "${users[@]}"; do
    mkdir -p "$BASE_DIR/$user"
done

# Write the global .gitconfig (heredoc content is written verbatim).
cat << EOF > "$BASE_DIR/.gitconfig"
[user]
    name = Camper User
    email = camper@example.com

# ์ฌ์ฉ์๋ณ ํด๋ ์ค์ ํฌํจ
EOF

# Append an includeIf stanza per member so git switches identities by folder.
for user in "${users[@]}"; do
    cat << EOF >> "$BASE_DIR/.gitconfig"
[includeIf "gitdir:$BASE_DIR/$user/"]
    path = $BASE_DIR/$user/.gitconfig
EOF
done

# Write each member's personal .gitconfig.
for user in "${users[@]}"; do
    cat << EOF > "$BASE_DIR/$user/.gitconfig"
[user]
    name = $user
    email = $user@example.com
EOF
done

# Ownership and permissions for the shared tree.
chown -R camper:camper "$BASE_DIR"
chmod -R 755 "$BASE_DIR"

echo "Git configuration setup completed!"

echo "Setup complete!"



##### Per-clone git setup (run manually inside each clone) #####
# git clone https://"$token"@github.com/boostcampaitech7/level2-nlp-generationfornlp-nlp-02-lv3.git

# git config --local user.email "$email"
# git config --local user.name "$username"
# git config --local credential.helper "cache --timeout=360000"
# git config --local commit.template .gitmessage.txt
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..06354ac
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[tool:pytest]
+addopts = -ra -v -l
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_sample.py b/tests/test_sample.py
new file mode 100644
index 0000000..d2b4018
--- /dev/null
+++ b/tests/test_sample.py
@@ -0,0 +1,2 @@
def test_add():
    # Smoke test: verifies the pytest wiring (setup.cfg addopts) runs at all.
    assert 1 + 2 == 3