diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100755 index 0000000..73fa5dc --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,11 @@ +blank_issues_enabled: false +issue_templates: + - name: Feature Template + description: Suggest an feature for this project ๐Ÿ‘ฉโ€๐Ÿ’ป + file: feature.md + - name: Experiment Template + description: Suggest an experiment for this project ๐Ÿง‘๐Ÿปโ€๐Ÿ”ฌ + file: experiment.md + - name: Research Template + description: Suggest an research to generate ideas ๐Ÿ‘จโ€๐Ÿซ + file: research.md diff --git a/.github/ISSUE_TEMPLATE/experiment.md b/.github/ISSUE_TEMPLATE/experiment.md new file mode 100755 index 0000000..2ad250e --- /dev/null +++ b/.github/ISSUE_TEMPLATE/experiment.md @@ -0,0 +1,38 @@ +--- +name: ๐Ÿ“Š Experiment Request +about: Suggest an experiment for this project ๐Ÿง‘๐Ÿปโ€๐Ÿ”ฌ +title: "[EXP]" +labels: experiment +assignees: +--- +# ๐Ÿ“Š Experiment + +## ๐Ÿฅš ์‹คํ—˜ ๊ทผ๊ฑฐ + +- ๋ ˆํผ๋Ÿฐ์Šค (๋…ผ๋ฌธ, ๊ฐ•์˜, ํฌ์ŠคํŒ…) +- ํ•ฉ๋ฆฌ์  ์ถ”๋ก  + +## ๐Ÿ“Ž ๋‚ด์šฉ + +- ์‹คํ—˜์— ๋Œ€ํ•œ ์ƒ์„ธํ•œ ๋‚ด์šฉ +- ์‹คํ—˜ ํ™˜๊ฒฝ๊ณผ ๋ณ€์ธ ํ†ต์ œ ๋ฐ˜๋“œ์‹œ ๊ธฐ์ž… + +## ๐Ÿฃ ์˜ˆ์ƒ ๊ฒฐ๊ณผ + +- ๋ฐ˜๋“œ์‹œ ์ด์œ ์™€ ํ•จ๊ป˜ ์˜ˆ์ƒ ๊ฒฐ๊ณผ ์ž‘์„ฑ + +## ๐Ÿณ ์‹ค์ œ ๊ฒฐ๊ณผ + +- ์˜ˆ์ƒ ๊ฒฐ๊ณผ์™€ ๋‹ฌ๋ž๋‹ค๋ฉด ๊ทธ ์ด์œ ๋„ ํ•จ๊ป˜ ์ž‘์„ฑ + +## ๐Ÿ“ ์‹คํ—˜ ์ •๋ณด + +- wandb ๋งํฌ +- ์ œ์ถœ ๊ฒฐ๊ณผ ๋…ธ์…˜ ๋งํฌ + +## ๐Ÿ“Œ ์ฒดํฌ๋ฆฌ์ŠคํŠธ + +- [ ] todo 1 +- [ ] todo 2 + +--- diff --git a/.github/ISSUE_TEMPLATE/feature.md b/.github/ISSUE_TEMPLATE/feature.md new file mode 100755 index 0000000..eaaddba --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature.md @@ -0,0 +1,20 @@ +--- +name: ๐Ÿš€ Feature Request +about: Suggest an feature for this project ๐Ÿ‘ฉโ€๐Ÿ’ป +title: "[FEAT]" +labels: enhancement +assignees: +--- +# ๐Ÿš€ Feature + +## ๐Ÿ“Ž ๋‚ด์šฉ + +- context 1 +- context 2 + +## ๐Ÿ“Œ ์ฒดํฌ๋ฆฌ์ŠคํŠธ + +- [ ] todo 1 +- [ ] todo 2 + +--- diff --git a/.github/ISSUE_TEMPLATE/research.md b/.github/ISSUE_TEMPLATE/research.md new file mode 100755 index 0000000..4395cc6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/research.md @@ -0,0 +1,25 @@ +--- +name: ๐Ÿ“š Research Request +about: Suggest an research to generate ideas ๐Ÿ‘จโ€๐Ÿซ +title: "[RES]" +labels: research +assignees: +--- +# ๐Ÿ“š Research + +## ๐Ÿ“Ž ๋‚ด์šฉ + +- context 1 +- context 2 + +## ๐Ÿง ๊ฒฐ๋ก  ๋ฐ ์‹คํ—˜ ๊ฐ€๋Šฅ์„ฑ + +- conclusion 1 +- conclusion 2 + +## ๐Ÿ“Œ ์ฒดํฌ๋ฆฌ์ŠคํŠธ + +- [ ] todo 1 +- [ ] todo 2 + +--- diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..c8e6346 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,11 @@ +## Description + +- ์ด๋ฒˆ PR์—์„œ ์ž‘์—…ํ•œ ๋‚ด์šฉ์„ ๊ฐ„๋žตํžˆ ์„ค๋ช… + +## Refer to the reviewer + +- ๋ฆฌ๋ทฐ์–ด์—๊ฒŒ ํ•„์š”ํ•œ ์„ค๋ช…์ด๋‚˜ ํŠน๋ณ„ํžˆ ๋ด์ฃผ์—ˆ์œผ๋ฉด ํ•˜๋Š” ๋ถ€๋ถ„์„ ์ž‘์„ฑ + +## Related Issue + +- #์ด์Šˆ๋ฒˆํ˜ธ diff --git a/.github/workflows/check-lint.yml b/.github/workflows/check-lint.yml new file mode 100755 index 0000000..c121e18 --- /dev/null +++ b/.github/workflows/check-lint.yml @@ -0,0 +1,24 @@ +name: check-lint + +on: [pull_request] + +jobs: + check-lint: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python3 -m pip install --upgrade pip + + - name: Check Lint + run: | + make quality diff --git a/.gitignore b/.gitignore new file mode 100644 index 
0000000..4c78c18 --- /dev/null +++ b/.gitignore @@ -0,0 +1,177 @@ +# Custom +.idea/ +**/data/ +**/output/ +**/outputs/ +**/wandb/ +**/*.out +config/*.yaml +config/token.json +config/credentials.json + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# Mac +**/.DS_Store +.vscode/settings.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..660d22a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - id: check-merge-conflict + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: local + hooks: + - id: pytest + name: pytest + entry: python3 -m pytest + language: system + pass_filenames: false + types: [python] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1dc5558 --- /dev/null +++ b/Makefile @@ -0,0 +1,48 @@ +clean: clean-pyc clean-test +quality: set-style-dep check-quality +style: set-style-dep set-style +setup: set-precommit set-style-dep set-test-dep set-git set-dev +test: set-test-dep set-test + + +##### basic ##### +set-git: + git config --local commit.template .gitmessage + +set-style-dep: + pip3 install ruff==0.7.2 + +set-test-dep: + pip3 install pytest==8.3.2 + +set-precommit: + pip3 install pre-commit==4.0.1 + pre-commit install + +set-dev: + pip3 install -r ./requirements.txt + +set-test: + python3 -m pytest tests/ + +set-style: + ruff check --fix . + ruff format . + +check-quality: + ruff check . + ruff format --check . + +##### clean ##### +clean-pyc: + find . -name '*.pyc' -exec rm -f {} + + find . -name '*.pyo' -exec rm -f {} + + find . -name '*~' -exec rm -f {} + + find . -name '__pycache__' -exec rm -fr {} + + +clean-test: + rm -f .coverage + rm -f .coverage.* + rm -rf .pytest_cache + rm -rf .mypy_cache + rm -rf .ruff_cache diff --git a/README.md b/README.md new file mode 100644 index 0000000..b7aaf50 --- /dev/null +++ b/README.md @@ -0,0 +1,196 @@ +
+ +# ๐Ÿ† Lv.2 NLP Project : ์ˆ˜๋Šฅ ๋ฌธ์ œ ํ’€์ด AI ๋ชจ๋ธ ์ƒ์„ฑ +
+ +## โœ๏ธ ๋Œ€ํšŒ ์†Œ๊ฐœ +| ํŠน์ง• | ์„ค๋ช… | +|:------:|--------------------------------------------------------------------------------------------------------| +| ๋Œ€ํšŒ ์ฃผ์ œ | ๋„ค์ด๋ฒ„ ๋ถ€์ŠคํŠธ์บ ํ”„ AI-Tech 7๊ธฐ NLP ํŠธ๋ž™์˜ level 2 Generation for NLP ๋Œ€ํšŒ
| +| ๋Œ€ํšŒ ์„ค๋ช… | ํ•œ๊ตญ์–ด์˜ ํŠน์„ฑ๊ณผ ์ˆ˜๋Šฅ ์‹œํ—˜์˜ ํŠน์ง•์„ ๋ฐ”ํƒ•์œผ๋กœ ์ˆ˜๋Šฅ์— ํŠนํ™”๋œ AI ๋ชจ๋ธ์„ ์ƒ์„ฑํ•˜๋Š” ํ”„๋กœ์ ํŠธ | +| ์ง„ํ–‰ ๊ธฐ๊ฐ„ |2024๋…„ 11์›” 11์›” ~ 2024๋…„ 11์›” 28์ผ| +| ๋ฐ์ดํ„ฐ ๊ตฌ์„ฑ | ํ•™์Šต๋ฐ์ดํ„ฐ ์…‹: KMMLU / MMMLU(Ko) / KLUE MRC ์ค‘ 2031๊ฐœ
ํ‰๊ฐ€๋ฐ์ดํ„ฐ ์…‹: ์ˆ˜๋Šฅํ˜• ๋ฌธ์ œ + KMMLU / MMMLU(Ko) / KLUE MRC ์ด 869๊ฐœ | +| ํ‰๊ฐ€ ์ง€ํ‘œ | ์ •ํ™•๋„(Accuracy) = ๋ชจ๋ธ์ด ๋งž์ถ˜ ๋ฌธ์ œ ์ˆ˜ / ์ „์ฒด ๋ฌธ์ œ ์ˆ˜ | + +## ๐ŸŽ–๏ธ Leader Board +### ๐Ÿฅˆ Public Leader Board (2์œ„) +image +### ๐Ÿฅˆ Priavate Leader Board (2์œ„) +image + +## ๐Ÿ‘จโ€๐Ÿ’ป Contributors + + + + + + + + + +
+| 이예서 | 김수진 | 김민서 | 홍성재 | 양가연 | 홍성민 |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+ +## ๐Ÿ‘ผ ์—ญํ•  ๋ถ„๋‹ด +| ์ด๋ฆ„ | ์—ญํ•  | +| --- |---------------------------------------------------------------------------------------------| +| ๊น€๋ฏผ์„œ | ์ตœ์ ํ™” ์†”๋ฃจ์…˜(DeepSpeed), ์–‘์žํ™”(Optimizer Quantization), ๋‚œ์ด๋„ ๊ธฐ๋ฐ˜ ๋ฐ์ดํ„ฐ ์ฆ๊ฐ• | +| ๊น€์ˆ˜์ง„ | EDA(๊ตญ์–ด์˜์—ญ๊ณผ ์‚ฌํšŒ์˜์—ญ ์ฐจ์ด ๋ถ„์„), ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘, LLM์„ ํ™œ์šฉํ•œ ๋ฐ์ดํ„ฐ ์ฆ๊ฐ•, ํ”„๋กฌํ”„ํŠธ ์‹คํ—˜ | +| ์–‘๊ฐ€์—ฐ | EDA(๊ตญ์–ด์˜์—ญ๊ณผ ์‚ฌํšŒ์˜์—ญ ์ฐจ์ด ๋ถ„์„), ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘, RAG ๊ตฌํ˜„(Dense Retrieval) | +| ์ด์˜ˆ์„œ | ๋ฉ”๋ชจ๋ฆฌ/์†๋„ ์ตœ์ ํ™”, ์–‘์žํ™”(BitsAndBytes, GPTQ), ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘, ๋ฐ์ดํ„ฐ ์ •์ œ, RAG ๊ตฌํ˜„(Elastic Search, Reranker, RAFT) | +| ํ™์„ฑ๋ฏผ | EDA(๋ฐ์ดํ„ฐ ์ถœ์ฒ˜ ๊ธฐ๋ฐ˜ ๋ถ„์„), LLM์„ ํ™œ์šฉํ•œ ๋ฐ์ดํ„ฐ ์ฆ๊ฐ• | +| ํ™์„ฑ์žฌ | EDA(๊ตญ์–ด์˜์—ญ๊ณผ ์‚ฌํšŒ์˜์—ญ ์ฐจ์ด ๋ถ„์„), streamlit ์‹œ๊ฐํ™” | + +## ๐Ÿ“ƒ Results +image + +## ๐Ÿ› ๏ธ**Dependencies** +``` +# CUDA Version: 12.2 +# Ubuntu 20.04.6 +# python 3.10.13 + +# Deep Learning +auto_gptq==0.7.1 +bitsandbytes==0.44.1 +evaluate==0.4.3 +huggingface-hub==0.26.2 +numpy==2.0.0 +optimum==1.23.3 +peft==0.5.0 +scikit-learn==1.5.2 +torch==2.5.1 # 2.5.1+cu124 +tqdm==4.67.0 +transformers==4.46.2 +trl==0.12.0 +wandb==0.18.5 + +# RAG +elasticsearch==8.16.0 +konlpy==0.6.0 +rank-bm25==0.2.2 +wikiextractor==3.0.6 +faiss-cpu==1.9.0 # faiss-gpu==1.7.2 + +# Utils +beautifulsoup4==4.12.3 +ipykernel==6.29.5 +ipywidgets==8.1.5 +loguru==0.7.2 +matplotlib==3.9.2 +python-dotenv==1.0.1 +reportlab==4.2.5 +streamlit==1.40.1 +pdfminer.six==20240706 + +# Google Drive API +google-api-python-client==2.151.0 +google-auth-httplib2==0.2.0 +google-auth-oauthlib==1.2.1 + +# Automatically installed dependencies +# pandas==2.2.3 +# pyarrow==18.0.0 +# datasets==3.1.0 +# safetensors==0.4.5 +# scipy==1.14.1 +# tqdm==4.67.0 +# PyYAML==6.0.2 +# requests==2.32.3 + +``` +## ๐Ÿ’พ Usage +1. Setting +``` +$ pip install -r requirements.txt +``` +2. 
train & inference +```angular2html +$ python3 code/main.py +``` + +## ๐Ÿ“ ํ”„๋กœ์ ํŠธ ๊ตฌ์กฐ +``` +code + โ”ฃ rag + โ”ƒ โ”ฃ data_process + โ”ƒ โ”ƒ โ”ฃ external_data.py + โ”ƒ โ”ƒ โ”— wiki_dump.py + โ”ƒ โ”ฃ README.md + โ”ƒ โ”ฃ __init__.py + โ”ƒ โ”ฃ chunk_data.py + โ”ƒ โ”ฃ dpr_data.py + โ”ƒ โ”ฃ encoder.py + โ”ƒ โ”ฃ index_runner.py + โ”ƒ โ”ฃ indexers.py + โ”ƒ โ”ฃ prepare_dense.py + โ”ƒ โ”ฃ reranker.py + โ”ƒ โ”ฃ retriever.py + โ”ƒ โ”ฃ retriever_bm25.py + โ”ƒ โ”ฃ retriever_elastic.py + โ”ƒ โ”ฃ train.py + โ”ƒ โ”ฃ trainer.py + โ”ƒ โ”— utils.py + โ”ฃ utils + โ”ƒ โ”ฃ __init__.py + โ”ƒ โ”ฃ common.py + โ”ƒ โ”ฃ gdrive_manager.py + โ”ƒ โ”— hf_manager.py + โ”ฃ data_loaders.py + โ”ฃ inference.py + โ”ฃ labeling.py + โ”ฃ main.py + โ”ฃ model.py + โ”ฃ split.py + โ”— trainer.py + data_aug + โ”ฃ add_CoT.py + โ”— aug_philo.py + data_process + โ”ฃ crawling_gichulpass.py + โ”ฃ external_musr.py + โ”ฃ external_race.py + โ”ฃ external_sat_gaokao.py + โ”ฃ pdf_to_txt.py + โ”ฃ process_balance_choices.py + โ”ฃ process_formatting.py + โ”— process_google_translate.py + data_viz + โ”ฃ csv2pdf.py + โ”ฃ labeling.py + โ”— streamlit_app.py +config + โ”ฃ sample + โ”ƒ โ”ฃ config.yaml + โ”ƒ โ”— env-sample.txt + โ”— elastic_setting.json +``` diff --git a/assets/final_result.png b/assets/final_result.png new file mode 100644 index 0000000..153ef8a Binary files /dev/null and b/assets/final_result.png differ diff --git a/assets/private_rank.png b/assets/private_rank.png new file mode 100644 index 0000000..20e1c94 Binary files /dev/null and b/assets/private_rank.png differ diff --git a/assets/public_rank.png b/assets/public_rank.png new file mode 100644 index 0000000..5dfb361 Binary files /dev/null and b/assets/public_rank.png differ diff --git a/code/data_loaders.py b/code/data_loaders.py new file mode 100644 index 0000000..5ab6565 --- /dev/null +++ b/code/data_loaders.py @@ -0,0 +1,385 @@ +from ast import literal_eval +import os +import pickle +from typing import Dict, List + +from datasets import Dataset +from dotenv import load_dotenv +from loguru import logger +import numpy as np +import pandas as pd +from rag import ElasticsearchRetriever, Reranker +from rag.dpr_data import KorQuadDataset +from rag.encoder import KobertBiEncoder +from rag.indexers import DenseFlatIndexer +from rag.retriever import KorDPRRetriever, get_passage_file +from utils import load_config + + +class DataLoader: + def __init__(self, tokenizer, data_config): + self.tokenizer = tokenizer + self.retriever_config = data_config["retriever"] + self.train_path = data_config["train_path"] + self.test_path = data_config["test_path"] + self.processed_train_path = data_config["processed_train_path"] + self.processed_test_path = data_config["processed_test_path"] + self.max_seq_length = data_config["max_seq_length"] + self.test_size = data_config["test_size"] + self.prompt_config = data_config["prompt"] + + def prepare_datasets(self, is_train): + """ํ•™์Šต ๋˜๋Š” ํ…Œ์ŠคํŠธ์šฉ ๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„""" + # prompt ์ „์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ์…‹ ํŒŒ์ผ์ด ์กด์žฌํ•œ๋‹ค๋ฉด ์ด๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค. 
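+        # 캐시된 CSV에는 messages 컬럼이 문자열로 저장되어 있으므로, 읽은 뒤 literal_eval로 dict 리스트 형태로 복원합니다.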
+ processed_df_path = self.processed_train_path if is_train else self.processed_test_path + if os.path.isfile(processed_df_path): + logger.info(f"์ „์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ์…‹์„ ๋ถˆ๋Ÿฌ์˜ต๋‹ˆ๋‹ค: {processed_df_path}") + processed_df = pd.read_csv(processed_df_path, encoding="utf-8") + processed_df["messages"] = processed_df["messages"].apply(literal_eval) + processed_dataset = Dataset.from_pandas(processed_df) + else: + dataset = self._load_data(is_train) + processed_dataset = self._process_dataset(dataset, is_train) + + if is_train: + tokenized_dataset = self._tokenize_dataset(processed_dataset) + splitted_dataset = self._split_dataset(tokenized_dataset) + return splitted_dataset + return processed_dataset + + def _retrieve(self, df): # noqa: C901 + if self.retriever_config["retriever_type"] == "Elasticsearch": + retriever = ElasticsearchRetriever( + index_name=self.retriever_config["index_name"], + ) + elif self.retriever_config["retriever_type"] == "DPR": + # KorDPRRetriever ์‚ฌ์šฉ + try: + model = KobertBiEncoder() # ๋ชจ๋ธ ์ดˆ๊ธฐํ™” + model.load("./rag/output/my_model.pt") # ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ + logger.debug("Model loaded successfully.") + assert model is not None, "Model is None after loading." + except Exception as e: + logger.debug(f"Error while loading model: {e}") + + try: + valid_dataset = KorQuadDataset("./rag/data/KorQuAD_v1.0_dev.json") # ๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„ + logger.debug("Valid dataset loaded successfully.") + except Exception as e: + logger.debug(f"Error while loading valid dataset: {e}") + + try: + index = DenseFlatIndexer() # ์ธ๋ฑ์Šค ์ค€๋น„ + index.deserialize(path="./rag/2050iter_flat/") + logger.debug("Index loaded successfully.") + assert index is not None, "Index is None after loading." + except Exception as e: + logger.debug(f"Error while loading index: {e}") + + ds_retriever = KorDPRRetriever(model=model, valid_dataset=valid_dataset, index=index) + logger.debug("KorDPRRetriever initialized successfully.") + else: + return [""] * len(df) + + def _combine_text(row): + # NaN ๊ฐ’ ์ฒ˜๋ฆฌ + paragraph = "" if pd.isna(row["paragraph"]) else str(row["paragraph"]) + if pd.isna(row["problems"]): + problems = {"question": "", "choices": []} + else: + problems = row["problems"] + question = str(problems.get("question", "")) + choices = [str(choice) for choice in problems.get("choices", [])] + + if self.retriever_config["query_type"] == "pqc": + return paragraph + " " + question + " " + " ".join(choices) + if self.retriever_config["query_type"] == "pq": + return paragraph + " " + question + if self.retriever_config["query_type"] == "pc": + return paragraph + " " + " ".join(choices) + else: + return paragraph + + top_k = self.retriever_config["top_k"] + threshold = self.retriever_config["threshold"] + query_max_length = self.retriever_config["query_max_length"] + + queries = df.apply(_combine_text, axis=1) + if self.retriever_config["retriever_type"] == "Elasticsearch": + filtered_queries = [(i, q) for i, q in enumerate(queries) if len(q) <= query_max_length] + if not filtered_queries: + return [""] * len(queries) + + indices, valid_queries = zip(*filtered_queries) + retrieve_results = retriever.bulk_retrieve(valid_queries, top_k) + rerank_k = self.retriever_config["rerank_k"] + if rerank_k > 0: + with Reranker() as reranker: + retrieve_results = reranker.rerank(valid_queries, retrieve_results, rerank_k) + # [[{"text":"์•ˆ๋…•ํ•˜์„ธ์š”", "score":0.5}, {"text":"๋ฐ˜๊ฐ‘์Šต๋‹ˆ๋‹ค", "score":0.3},],] + + docs = [""] * len(queries) + for idx, result in zip(indices, retrieve_results): + 
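+            # result 예시: [{"text": "passage1", "score": 0.8}, {"text": "passage2", "score": 0.4}]
+            # score가 threshold 이상인 passage만 공백으로 이어 붙인 뒤, result_max_length 길이로 잘라 해당 문항의 참고 문서로 사용합니다.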
docs[idx] = " ".join(item["text"] for item in result if item["score"] >= threshold) + docs[idx] = docs[idx][: self.retriever_config["result_max_length"]] + elif self.retriever_config["retriever_type"] == "DPR": # DPR์ธ ๊ฒฝ์šฐ + docs = [] + for query in queries: + passages = ds_retriever.retrieve(query=query, k=top_k) # DPR์œผ๋กœ ๊ฒ€์ƒ‰ + + # passage ๋กœ๋”ฉ ๋ฐ ๊ฒฐํ•ฉ + for idx, (passage, score) in enumerate(passages): + # passage ID์— ํ•ด๋‹นํ•˜๋Š” ํŒŒ์ผ ๊ฒฝ๋กœ ๊ฐ€์ ธ์˜ค๊ธฐ + path = get_passage_file([idx]) + if path: + with open(path, "rb") as f: + passage_dict = pickle.load(f) + docs.append((passage_dict[idx], score)) # passage์™€ score ์ €์žฅ + else: + logger.debug(f"No passage found for ID: {idx}") + + # ๋กœ๊น… ์ถ”๊ฐ€ + logger.info(f"๊ฐ€์—ฐ Query: {query}") + logger.info(f"Rank {idx+1}: Score: {score:.4f}, Passage: {passage}") + + return docs + + def _load_data(self, is_train) -> List[Dict]: + """csv๋ฅผ ์ฝ์–ด์˜ค๊ณ  dictionary ๋ฐฐ์—ด ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + file_path = self.train_path if is_train else self.test_path + df = pd.read_csv(file_path) + df["problems"] = df["problems"].apply(literal_eval) + docs = self._retrieve(df) + records = [] + for idx, row in df.iterrows(): + problems = row["problems"] + record = { + "id": row["id"], + "paragraph": row["paragraph"], + "question": problems["question"], + "choices": problems["choices"], + "answer": problems.get("answer", None), + "question_plus": problems.get("question_plus", None), + "document": docs[idx], + } + records.append(record) + logger.info("dataset ๋กœ๋“œ ๋ฐ retrive ์™„๋ฃŒ.") + return records + + def _process_dataset(self, dataset: List[Dict], is_train=True): + """๋ฐ์ดํ„ฐ์— ํ”„๋กฌํ”„ํŠธ ์ ์šฉ""" + + # ๋ฐ์ดํ„ฐ์…‹์„ prompt ์ „์ฒ˜๋ฆฌํ•˜๊ณ  ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. + logger.info("๋ฐ์ดํ„ฐ์…‹ ์ „์ฒ˜๋ฆฌ๋ฅผ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.") + processed_data = [] + for row in dataset: + choices_string = "\n".join([f"{idx + 1} - {choice}" for idx, choice in enumerate(row["choices"])]) + + # start + if row["question_plus"]: + message_start = self.prompt_config["start_with_plus"].format( + paragraph=row["paragraph"], + question=row["question"], + question_plus=row["question_plus"], + choices=choices_string, + ) + else: + message_start = self.prompt_config["start"].format( + paragraph=row["paragraph"], + question=row["question"], + choices=choices_string, + ) + # mid + if row["document"]: + message_mid = self.prompt_config["mid_with_document"].format( + document=row["document"], + ) + else: + message_mid = self.prompt_config["mid"] + # end + message_end = self.prompt_config["end"] + + user_message = message_start + message_mid + message_end + messages = [ + {"role": "system", "content": "์ง€๋ฌธ์„ ์ฝ๊ณ  ์งˆ๋ฌธ์˜ ๋‹ต์„ ๊ตฌํ•˜์„ธ์š”."}, + {"role": "user", "content": user_message}, + ] + + if is_train: + messages.append({"role": "assistant", "content": f"{row['answer']}"}) + + processed_data.append({"id": row["id"], "messages": messages, "label": row["answer"] if is_train else None}) + + processed_df = pd.DataFrame(processed_data) + logger.info("๋ฐ์ดํ„ฐ์…‹ ์ „์ฒ˜๋ฆฌ๊ฐ€ ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + processed_df_path = self.processed_train_path if is_train else self.processed_test_path + if processed_df_path: + processed_df.to_csv(processed_df_path, index=False, encoding="utf-8") + logger.info("์ „์ฒ˜๋ฆฌ๋œ ๋ฐ์ดํ„ฐ์…‹์ด ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + return Dataset.from_pandas(processed_df) + + def _tokenize_dataset(self, dataset): + def formatting_prompts_func(example): + output_texts = [] + for i in range(len(example["messages"])): + 
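+                # apply_chat_template(tokenize=False)는 system/user/assistant messages 리스트를
+                # 모델의 채팅 템플릿이 적용된 하나의 프롬프트 문자열로 변환합니다(적용되는 템플릿은 tokenizer 설정의 chat_template을 따릅니다).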
output_texts.append( + self.tokenizer.apply_chat_template( + example["messages"][i], + tokenize=False, + ) + ) + return output_texts + + def tokenize(element): + outputs = self.tokenizer( + formatting_prompts_func(element), + truncation=False, + padding=False, + return_overflowing_tokens=False, + return_length=False, + ) + return { + "input_ids": outputs["input_ids"], + "attention_mask": outputs["attention_mask"], + } + + tokenized_dataset = dataset.map( + tokenize, + remove_columns=list(dataset.features), + batched=True, + num_proc=4, + load_from_cache_file=True, + desc="Tokenizing", + ) + + # ํ† ํฐ ๊ธธ์ด๊ฐ€ max_seq_length๋ฅผ ์ดˆ๊ณผํ•˜๋Š” ๋ฐ์ดํ„ฐ ํ•„ํ„ฐ๋ง + logger.info(f"dataset length: {len(tokenized_dataset)}") + tokenized_dataset = tokenized_dataset.filter(lambda x: len(x["input_ids"]) <= self.max_seq_length) + logger.info(f"filtered dataset length: {len(tokenized_dataset)}") + + return tokenized_dataset + + def _split_dataset(self, dataset): + split_dataset = dataset.train_test_split(test_size=self.test_size, seed=42) + train_dataset = split_dataset["train"] + eval_dataset = split_dataset["test"] + + logger.debug(self.tokenizer.decode(train_dataset[0]["input_ids"], skip_special_tokens=True)) + train_dataset_token_lengths = [len(train_dataset[i]["input_ids"]) for i in range(len(train_dataset))] + logger.info(f"max token length: {max(train_dataset_token_lengths)}") + logger.info(f"min token length: {min(train_dataset_token_lengths)}") + logger.info(f"avg token length: {np.mean(train_dataset_token_lengths)}") + + return train_dataset, eval_dataset + + +if __name__ == "__main__": + config_folder = os.path.join(os.path.dirname(__file__), "..", "config/") + load_dotenv(os.path.join(config_folder, ".env")) + config = load_config() + data_config = config["data"] + + def _retrieve(retriever_config, df): # noqa: C901 + if retriever_config["retriever_type"] == "Elasticsearch": + retriever = ElasticsearchRetriever( + index_name=retriever_config["index_name"], + ) + elif retriever_config["retriever_type"] == "BM25": + raise NotImplementedError("BM25๋Š” ๋” ์ด์ƒ ์ง€์›ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. Elasticsearch๋ฅผ ์‚ฌ์šฉํ•ด์ฃผ์„ธ์š”...") + + elif retriever_config["retriever_type"] == "DPR": + # KorDPRRetriever ์‚ฌ์šฉ + try: + model = KobertBiEncoder() # ๋ชจ๋ธ ์ดˆ๊ธฐํ™” + model.load("./rag/output/my_model.pt") # ๋ชจ๋ธ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ + logger.debug("Model loaded successfully.") + assert model is not None, "Model is None after loading." + except Exception as e: + logger.debug(f"Error while loading model: {e}") + + try: + valid_dataset = KorQuadDataset("./rag/data/KorQuAD_v1.0_dev.json") # ๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„ + logger.debug("Valid dataset loaded successfully.") + except Exception as e: + logger.debug(f"Error while loading valid dataset: {e}") + + try: + index = DenseFlatIndexer() # ์ธ๋ฑ์Šค ์ค€๋น„ + index.deserialize(path="./rag/2050iter_flat/") + logger.debug("Index loaded successfully.") + assert index is not None, "Index is None after loading." 
+ except Exception as e: + logger.debug(f"Error while loading index: {e}") + + ds_retriever = KorDPRRetriever(model=model, valid_dataset=valid_dataset, index=index) + logger.debug("KorDPRRetriever initialized successfully.") + + else: + return [""] * len(df) + + def _combine_text(row): + if retriever_config["query_type"] == "pqc": + return row["paragraph"] + " " + row["problems"]["question"] + " " + " ".join(row["problems"]["choices"]) + if retriever_config["query_type"] == "pq": + return row["paragraph"] + " " + row["problems"]["question"] + if retriever_config["query_type"] == "pc": + return row["paragraph"] + " " + " ".join(row["problems"]["choices"]) + else: + return row["paragraph"] + + top_k = retriever_config["top_k"] + threshold = retriever_config["threshold"] + query_max_length = retriever_config["query_max_length"] + + queries = df.apply(_combine_text, axis=1) + if retriever_config["retriever_type"] == "Elasticsearch": + filtered_queries = [(i, q) for i, q in enumerate(queries) if len(q) <= query_max_length] + if not filtered_queries: + return [""] * len(queries) + + indices, valid_queries = zip(*filtered_queries) + retrieve_results = retriever.bulk_retrieve(valid_queries, top_k) + rerank_k = retriever_config["rerank_k"] + if rerank_k > 0: + with Reranker() as reranker: + retrieve_results = reranker.rerank(valid_queries, retrieve_results, rerank_k) + # [[{"text":"์•ˆ๋…•ํ•˜์„ธ์š”", "score":0.5}, {"text":"๋ฐ˜๊ฐ‘์Šต๋‹ˆ๋‹ค", "score":0.3},],] + + docs = [""] * len(queries) + for idx, result in zip(indices, retrieve_results): + docs[idx] = " ".join( + f"[{item['score']}]: {item['text']}" for item in result if item["score"] >= threshold + ) + + elif retriever_config["retriever_type"] == "DPR": # DPR์ธ ๊ฒฝ์šฐ + docs = [] + for query in queries: + passages = ds_retriever.retrieve(query=query, k=top_k) # DPR์œผ๋กœ ๊ฒ€์ƒ‰ + + # passage ๋กœ๋”ฉ ๋ฐ ๊ฒฐํ•ฉ + for idx, (passage, score) in enumerate(passages): + # passage ID์— ํ•ด๋‹นํ•˜๋Š” ํŒŒ์ผ ๊ฒฝ๋กœ ๊ฐ€์ ธ์˜ค๊ธฐ + path = get_passage_file([idx]) + if path: + with open(path, "rb") as f: + passage_dict = pickle.load(f) + docs.append((passage_dict[idx], score)) # passage์™€ score ์ €์žฅ + else: + logger.debug(f"No passage found for ID: {idx}") + + # ๋กœ๊น… ์ถ”๊ฐ€ + logger.info(f"Query: {query}") + logger.info(f"Rank {idx+1}: Score: {score:.4f}, Passage: {passage}") + return docs + + def load_and_save(retriever_config, file_path) -> List[Dict]: + """csv๋ฅผ ์ฝ์–ด์˜ค๊ณ  dictionary ๋ฐฐ์—ด ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + df = pd.read_csv(file_path) + df["problems"] = df["problems"].apply(literal_eval) + docs = _retrieve(retriever_config, df) + df["documents"] = docs + df.to_csv(file_path.replace(".csv", "_retrieve.csv"), index=False) + logger.debug("retrieve ๊ฒฐ๊ณผ๊ฐ€ csv๋กœ ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + + load_and_save(data_config["retriever"], data_config["train_path"]) + load_and_save(data_config["retriever"], data_config["test_path"]) diff --git a/code/inference.py b/code/inference.py new file mode 100644 index 0000000..58767f5 --- /dev/null +++ b/code/inference.py @@ -0,0 +1,48 @@ +from loguru import logger +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + + +class InferenceModel: + def __init__(self, inference_config, model, tokenizer, test_dataset): + self.inference_config = inference_config + self.model = model + self.tokenizer = tokenizer + self.test_dataset = test_dataset + self.pred_choices_map = {0: "1", 1: "2", 2: "3", 3: "4", 4: "5"} + + def run_inference(self): + if not self.inference_config["do_test"]: + 
logger.info("์ถ”๋ก  ๋‹จ๊ณ„๋ฅผ ์ƒ๋žตํ•ฉ๋‹ˆ๋‹ค. inference do_test ์„ค์ •์„ ํ™•์ธํ•˜์„ธ์š”.") + return + + results = self._inference(self.test_dataset) + return self._save_results(results) + + def _inference(self, test_dataset): + infer_results = [] + self.model.config.use_cache = True + self.model.eval() + + with torch.inference_mode(): + for example in tqdm(test_dataset): + outputs = self.model( + self.tokenizer.apply_chat_template( + example["messages"], tokenize=True, add_generation_prompt=True, return_tensors="pt" + ).to("cuda") + ) + + logits = outputs.logits[:, -1].flatten().cpu() + target_logits = [logits[self.tokenizer.vocab[str(i + 1)]] for i in range(5)] # ์„ ํƒ์ง€๋Š” ํ•ญ์ƒ 5๊ฐœ + probs = torch.nn.functional.softmax(torch.tensor(target_logits, dtype=torch.float32), dim=-1) + predict_value = self.pred_choices_map[np.argmax(probs.detach().cpu().numpy())] + + infer_results.append({"id": example["id"], "answer": predict_value}) + + return infer_results + + def _save_results(self, results): + logger.info(self.inference_config["output_path"]) + pd.DataFrame(results).to_csv(self.inference_config["output_path"], index=False) diff --git a/code/labeling.py b/code/labeling.py new file mode 100644 index 0000000..e9956c7 --- /dev/null +++ b/code/labeling.py @@ -0,0 +1,119 @@ +from cleanlab.classification import CleanLearning +from cleanlab.filter import find_label_issues +from loguru import logger +import numpy as np +import pandas as pd +from sklearn.cluster import KMeans +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.model_selection import StratifiedKFold +from xgboost import XGBClassifier + + +# Pandas ์ถœ๋ ฅ ์„ค์ • +pd.set_option("display.max_columns", None) +pd.set_option("display.max_rows", None) +pd.set_option("display.max_colwidth", None) + + +def create_initial_labels(input_file, output_file, num_clusters=2): + """TF-IDF์™€ K-means๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ดˆ๊ธฐ ๋ผ๋ฒจ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.""" + df = pd.read_csv(input_file) + df.dropna(subset=["paragraph", "problems"], inplace=True) + df["combined_text"] = df["paragraph"] + " " + df["problems"] + + vectorizer = TfidfVectorizer() + X = vectorizer.fit_transform(df["combined_text"]) + + kmeans = KMeans(n_clusters=num_clusters, random_state=42) + kmeans.fit(X) + + df["target"] = kmeans.labels_ + final_columns = ["id", "paragraph", "problems", "question_plus", "target"] + df[final_columns].to_csv(output_file, index=False) + logger.info(f"์ดˆ๊ธฐ ๋ผ๋ฒจ๋ง์ด ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. 
๊ฒฐ๊ณผ๊ฐ€ {output_file}์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + + +def load_and_preprocess_data(file_path): + """๋ฐ์ดํ„ฐ๋ฅผ ๋กœ๋“œํ•˜๊ณ  ์ „์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค.""" + df = pd.read_csv(file_path) + if "target" not in df.columns: + raise ValueError("๋ฐ์ดํ„ฐ์…‹์— 'target' ์—ด์ด ์—†์Šต๋‹ˆ๋‹ค.") + + X = df[["paragraph", "problems"]].astype(str).agg(" ".join, axis=1) + y = df["target"].astype(int) + return df, X, y + + +def vectorize_text(X, max_features=5000): + """ํ…์ŠคํŠธ ๋ฐ์ดํ„ฐ๋ฅผ ๋ฒกํ„ฐํ™”ํ•ฉ๋‹ˆ๋‹ค.""" + vectorizer = TfidfVectorizer(max_features=max_features) + return vectorizer.fit_transform(X) + + +def train_and_predict(X_vectorized, y, n_splits=5): + """๋ชจ๋ธ์„ ํ›ˆ๋ จํ•˜๊ณ  ์˜ˆ์ธก ํ™•๋ฅ ์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + base_model = XGBClassifier(eval_metric="mlogloss", n_estimators=100) + model = CleanLearning(base_model) + skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42) + pred_probs = np.zeros((len(y), len(np.unique(y)))) + + for fold, (train_index, val_index) in enumerate(skf.split(X_vectorized, y), 1): + logger.info(f"Fold {fold}/{n_splits}") + X_train, X_val = X_vectorized[train_index], X_vectorized[val_index] + y_train, _ = y[train_index], y[val_index] + model.fit(X_train, y_train) + pred_probs[val_index] = model.predict_proba(X_val) + + return pred_probs + + +def find_and_update_label_issues(df, y, pred_probs): + """๋ ˆ์ด๋ธ” ์ด์Šˆ๋ฅผ ์ฐพ๊ณ  ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„์„ ์—…๋ฐ์ดํŠธํ•ฉ๋‹ˆ๋‹ค.""" + label_issues = find_label_issues(labels=y, pred_probs=pred_probs, return_indices_ranked_by="self_confidence") + df["is_label_issue"] = False + df.loc[label_issues, "is_label_issue"] = True + df["suggested_label"] = np.argmax(pred_probs, axis=1) + return df + + +def save_and_print_results(df, output_file): + """๊ฒฐ๊ณผ๋ฅผ ์ €์žฅํ•˜๊ณ  ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค.""" + final_columns = ["id", "paragraph", "problems", "question_plus", "target", "suggested_label", "is_label_issue"] + df[final_columns].to_csv(output_file, index=False) + + logger.info("\nID์™€ ์ œ์•ˆ๋œ ๋ ˆ์ด๋ธ”:") + logger.info(df[["id", "suggested_label"]].to_string(index=False)) + + logger.info("\n๋ ˆ์ด๋ธ” ์ด์Šˆ ํ†ต๊ณ„:") + logger.info(df["is_label_issue"].value_counts(normalize=True)) + + logger.info("\n์›๋ž˜ ๋ ˆ์ด๋ธ”๊ณผ ์ œ์•ˆ๋œ ๋ ˆ์ด๋ธ” ๋น„๊ต:") + logger.info(pd.crosstab(df["target"], df["suggested_label"])) + + +def main(): + initial_input_file = "../data/train.csv" + initial_output_file = "../data/output_with_labels.csv" + final_output_file = "../data/cleaned_output_with_labels_CL.csv" + + # ์ดˆ๊ธฐ ๋ผ๋ฒจ๋ง ์ˆ˜ํ–‰ + create_initial_labels(initial_input_file, initial_output_file) + + # ๋ฐ์ดํ„ฐ ๋กœ๋“œ ๋ฐ ์ „์ฒ˜๋ฆฌ + df, X, y = load_and_preprocess_data(initial_output_file) + + # ํ…์ŠคํŠธ ๋ฒกํ„ฐํ™” + X_vectorized = vectorize_text(X) + + # ๋ชจ๋ธ ํ›ˆ๋ จ ๋ฐ ์˜ˆ์ธก + pred_probs = train_and_predict(X_vectorized, y) + + # ๋ ˆ์ด๋ธ” ์ด์Šˆ ์ฐพ๊ธฐ ๋ฐ ์—…๋ฐ์ดํŠธ + df = find_and_update_label_issues(df, y, pred_probs) + + # ๊ฒฐ๊ณผ ์ €์žฅ ๋ฐ ์ถœ๋ ฅ + save_and_print_results(df, final_output_file) + + +if __name__ == "__main__": + main() diff --git a/code/main.py b/code/main.py new file mode 100644 index 0000000..a8733ec --- /dev/null +++ b/code/main.py @@ -0,0 +1,84 @@ +import os + +from data_loaders import DataLoader +from inference import InferenceModel +from loguru import logger +from model import ModelHandler +from trainer import CustomTrainer +from utils import ( + GoogleDriveManager, + create_experiment_filename, + load_config, + load_env_file, + log_config, + set_logger, + set_seed, +) +import wandb + + +def main(): + 
# env, config, log, seed ์„ค์ • + load_env_file() + config = load_config() + set_logger(log_file=config["log"]["file"], log_level=config["log"]["level"]) + set_seed() + + # wandb ์„ค์ • + exp_name = create_experiment_filename(config) + wandb.init( + config=config, + project=config["wandb"]["project"], + entity=config["wandb"]["entity"], + name=exp_name, + ) + + # wandb ์‹คํ—˜๋ช…์œผ๋กœ config ๊ฐฑ์‹  + config["training"]["run_name"] = exp_name + config["inference"]["output_path"] = os.path.join(config["inference"]["output_path"], exp_name + "_output.csv") + log_config(config) + + try: + # ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ € ์„ค์ • + model_handler = ModelHandler(config["model"]) + model, tokenizer = model_handler.setup() + + # ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ + data_processor = DataLoader(tokenizer, config["data"]) + train_dataset, eval_dataset = data_processor.prepare_datasets(is_train=True) + test_dataset = data_processor.prepare_datasets(is_train=False) + + # ํ•™์Šต + trainer = CustomTrainer( + training_config=config["training"], + model=model, + tokenizer=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + trained_model = trainer.train() + + # ์ถ”๋ก  + inferencer = InferenceModel( + inference_config=config["inference"], + model=trained_model, + tokenizer=tokenizer, + test_dataset=test_dataset, + ) + inferencer.run_inference() + + except Exception as e: + logger.exception(f"Error occurred: {e}") + wandb.finish(exit_code=1) + else: + logger.info("Upload output & config to GDrive...") + gdrive_manager = GoogleDriveManager() + gdrive_manager.upload_exp( + config["exp"]["username"], + config["inference"]["output_path"], + ) + wandb.finish() + + +if __name__ == "__main__": + main() diff --git a/code/model.py b/code/model.py new file mode 100644 index 0000000..6d7fd69 --- /dev/null +++ b/code/model.py @@ -0,0 +1,60 @@ +from loguru import logger +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + + +class ModelHandler: + def __init__(self, model_config): + self.base_model = model_config["base_model"] + self.model_config = model_config["model"] + self.tokenizer_config = model_config["tokenizer"] + + def setup(self): + model = self._load_model() + tokenizer = self._load_tokenizer() + return model, tokenizer + + def _load_model(self): + torch_dtype = getattr(torch, self.model_config["torch_dtype"]) + base_kwargs = {"trust_remote_code": True, "low_cpu_mem_usage": self.model_config["low_cpu_mem_usage"]} + + if self.model_config["quantization"] == "BitsAndBytes": + bits = self.model_config["bits"] + if bits == 8: + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + bnb_8bit_use_double_quant=self.model_config["use_double_quant"], + bnb_8bit_compute_dtype=torch_dtype, + ) + elif bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=self.model_config["use_double_quant"], + bnb_4bit_compute_dtype=torch_dtype, + ) + else: + raise ValueError(f"Unsupported bits value: {bits}") + + base_kwargs["quantization_config"] = quantization_config + elif self.model_config["quantization"] == "auto": + base_kwargs["torch_dtype"] = "auto" + base_kwargs["device_map"] = "auto" + else: + base_kwargs["torch_dtype"] = torch_dtype + + logger.debug(f"base_kwargs: {base_kwargs}") + model = AutoModelForCausalLM.from_pretrained(self.base_model, **base_kwargs) + model.config.use_cache = self.model_config["use_cache"] + return model + + def _load_tokenizer(self): + tokenizer = 
AutoTokenizer.from_pretrained(self.base_model, trust_remote_code=True) + self._setup_tokenizer(tokenizer) + return tokenizer + + def _setup_tokenizer(self, tokenizer): + tokenizer.chat_template = self.tokenizer_config["chat_template"] + tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = self.tokenizer_config["padding_side"] diff --git a/code/rag/README.md b/code/rag/README.md new file mode 100644 index 0000000..4a54cc8 --- /dev/null +++ b/code/rag/README.md @@ -0,0 +1,34 @@ +# Dense Retriever ์‚ฌ์šฉ ๊ฐ€์ด๋“œ + +์ด ๊ฐ€์ด๋“œ๋Š” Dense Retriever๋ฅผ ์„ค์ •ํ•˜๊ณ  ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค. + +## ์ค€๋น„ ๋‹จ๊ณ„ + +1. **์œ„ํ‚คํ”ผ๋””์•„ ๋คํ”„ ํŒŒ์ผ ์ค€๋น„** + - `rag` ํด๋” ๋‚ด์— ์œ„ํ‚คํ”ผ๋””์•„ ๋คํ”„ ํŒŒ์ผ์„ ๋‹ค์šด๋กœ๋“œํ•˜๊ณ  ์••์ถ•์„ ํ•ด์ œํ•ฉ๋‹ˆ๋‹ค. + - `text` ํด๋” ๋‚ด์— `AA`, `AB`, `AC` ํด๋”๊ฐ€ ์กด์žฌํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. + - ๊ฐ ํด๋” ์•ˆ์— `wiki_`๋กœ ์‹œ์ž‘ํ•˜๋Š” ํŒŒ์ผ๋“ค์ด ์žˆ์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. +2. **KorQuAD_v1.0 ๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„** + - `data` ํด๋” ๋‚ด์— KorQuAD_v1.0_dev, KorQuAD_v1.0_train ํŒŒ์ผ์„ ์ค€๋น„ํ•ด์•ผํ•ฉ๋‹ˆ๋‹ค. + - https://korquad.github.io/category/1.0_KOR.html +3. **๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ** + - `prepare_dense.py` ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค. + - ์‹คํ–‰ ํ›„ ๋‹ค์Œ ํŒŒ์ผ๋“ค์ด ์ƒ์„ฑ๋˜์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค: + - `preproccessed_passages/0-XXXX.p,XXXX-XXXX.p...` + - `titled_passage_map.p` + - `2050iter_flat/index_meta.dpr,index.dpr` + - ์ƒ์„ฑ๋œ ํŒŒ์ผ์„ ์‚ฌ์šฉํ•˜์—ฌ ์˜ˆ์‹œ ์ฟผ๋ฆฌ์— ๋Œ€ํ•œ ์ ์ ˆํ•œ ๋ฌธ์„œ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค. +## ์‚ฌ์šฉ ๋ฐฉ๋ฒ• + +1. **์„ค์ • ํŒŒ์ผ ์ˆ˜์ •** + - `config` ํŒŒ์ผ์—์„œ `retriever_type`์„ `"DPR"`๋กœ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค. + +2. **์‹คํ–‰** + - ๋‹ค์Œ ๋ช…๋ น์–ด๋ฅผ ์‹คํ–‰ํ•˜์—ฌ Dense Retriever๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค: + ```bash + python main.py + ``` + +## ์ฃผ์˜์‚ฌํ•ญ +- ์ด ์ฝ”๋“œ๋Š” https://github.com/TmaxEdu/KorDPR๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ ์ž‘์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. +- ์œ„์˜ ๋‹จ๊ณ„๋“ค์„ ์ˆœ์„œ๋Œ€๋กœ ์ง„ํ–‰ํ•ด์•ผ Dense Retriever๊ฐ€ ์ •์ƒ์ ์œผ๋กœ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค. diff --git a/code/rag/__init__.py b/code/rag/__init__.py new file mode 100644 index 0000000..ca0ab1e --- /dev/null +++ b/code/rag/__init__.py @@ -0,0 +1,18 @@ +# # __init__.py ํŒŒ์ผ ๋‚ด์—์„œ ํ•„์š”ํ•œ ๋ชจ๋“ˆ๋“ค์„ ์ž„ํฌํŠธ +from .chunk_data import DataChunk, save_orig_passage, save_title_index_map +from .dpr_data import KorQuadDataset, KorQuadSampler, korquad_collator +from .encoder import KobertBiEncoder + +# from .retriever_dense import DenseRetriever # ์ด ๋ถ€๋ถ„๋„ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. 
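+# 현재 노출되는 것은 chunk/dpr 데이터 유틸, KobertBiEncoder, Reranker, BM25/Elasticsearch retriever, Trainer이며,
+# IndexRunner, KorDPRRetriever, DenseFlatIndexer 등 DPR 인덱싱 관련 import는 필요할 때 주석을 해제해 사용합니다.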
+from .reranker import Reranker + +# #from .utils import get_wiki_filepath, wiki_worker_init # ๋ณ€๊ฒฝ ์—†์Œ +# # import transformers +# # # ์™ธ๋ถ€ ์Šคํฌ๋ฆฝํŠธ์—์„œ IndexRunner ์ž„ํฌํŠธ +# from .index_runner import IndexRunner +# from .retriever import KorDPRRetriever # retriever.py์—์„œ ๊ฐ€์ ธ์˜ค๊ธฐ +# from .indexers import DenseFlatIndexer +# # # ์ถ”๊ฐ€๋œ ๋ถ€๋ถ„ +from .retriever_bm25 import BM25Retriever +from .retriever_elastic import ElasticsearchRetriever +from .trainer import Trainer diff --git a/code/rag/chunk_data.py b/code/rag/chunk_data.py new file mode 100644 index 0000000..3bfe86f --- /dev/null +++ b/code/rag/chunk_data.py @@ -0,0 +1,111 @@ +from collections import defaultdict +from glob import glob +import logging +import os +import pickle + +from tqdm import tqdm +from transformers import AutoTokenizer + + +os.makedirs("logs", exist_ok=True) +logging.basicConfig( + filename="logs/log.log", + level=logging.DEBUG, + format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s", +) +logger = logging.getLogger() + + +class DataChunk: + """์ธํ’‹ text๋ฅผ tokenizingํ•œ ๋’ค์— ์ฃผ์–ด์ง„ ๊ธธ์ด๋กœ chunking ํ•ด์„œ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. + ์ด๋•Œ ํ•˜๋‚˜์˜ chunk(context, index ๋‹จ์œ„)๋Š” ํ•˜๋‚˜์˜ article์—๋งŒ ์†ํ•ด์žˆ์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.""" + + def __init__(self, chunk_size=100): + self.chunk_size = chunk_size + self.tokenizer = AutoTokenizer.from_pretrained("monologg/kobert", trust_remote_code=True) + + def chunk(self, input_file): + logger.info(f"Processing file: {input_file}") + with open(input_file, "rt", encoding="utf8") as f: + input_txt = f.read().strip() + input_txt = input_txt.split("") + chunk_list = [] + orig_text = [] + for art in input_txt: + art = art.strip() + if not art: + logger.debug("Article is empty, passing") + continue + title = art.split("\n")[0].strip(">").split("title=")[1].strip('"') + text = "\n".join(art.split("\n")[2:]).strip() + + logger.debug(f"Processing article: {title}") + + encoded_title = self.tokenizer.encode(title, add_special_tokens=True) + encoded_txt = self.tokenizer.encode(text, add_special_tokens=True) + if len(encoded_txt) < 5: + logger.debug(f"Title {title} has <5 subwords in its article, passing") + continue + + for start_idx in range(0, len(encoded_txt), self.chunk_size): + end_idx = min(len(encoded_txt), start_idx + self.chunk_size) + chunk = encoded_title + encoded_txt[start_idx:end_idx] + orig_text.append(self.tokenizer.decode(chunk)) + chunk_list.append(chunk) + + logger.info(f"Processed {len(orig_text)} chunks from {input_file}.") + return orig_text, chunk_list + + +def save_orig_passage(input_path="text", passage_path="processed_passages", chunk_size=100): + os.makedirs(passage_path, exist_ok=True) + app = DataChunk(chunk_size=chunk_size) + idx = 0 + for path in tqdm(glob(f"{input_path}/*/wiki_*")): + ret, _ = app.chunk(path) + logger.info(f"Processed {len(ret)} chunks from {path}.") # ์ถ”๊ฐ€๋œ ๋กœ๊ทธ + if len(ret) > 0: # ์ฒญํฌ๊ฐ€ ์žˆ๋Š” ๊ฒฝ์šฐ์—๋งŒ ์ €์žฅ + to_save = {idx + i: ret[i] for i in range(len(ret))} + with open(f"{passage_path}/{idx}-{idx+len(ret)-1}.p", "wb") as f: + pickle.dump(to_save, f) + idx += len(ret) + + +def save_title_index_map(index_path="title_passage_map.p", source_passage_path="processed_passages"): + logging.getLogger() + logger.debug(f"Looking for files in {source_passage_path}") + files = glob(f"{source_passage_path}/*") + logger.debug(f"Found {len(files)} files") + + title_id_map = defaultdict(list) + for f in tqdm(files): + logger.debug(f"Processing file: {f}") + with open(f, "rb") as _f: + 
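+            # 각 pickle 파일은 {passage_id: "[CLS] 제목 [SEP] 본문 ..."} 형태의 dict이며,
+            # 아래에서 passage 앞부분의 제목을 파싱해 title -> passage_id 리스트 매핑을 만듭니다.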
id_passage_map = pickle.load(_f) + + # ๋กœ๊ทธ ์ถ”๊ฐ€: id_passage_map์˜ ํ˜•์‹ ๋ฐ ๋‚ด์šฉ ํ™•์ธ + logger.debug(f"Loaded {len(id_passage_map)} passages from {f}") + logger.debug(f"Sample passage: {list(id_passage_map.items())[:5]}") # ์ฒซ 5๊ฐœ ํ•ญ๋ชฉ ์ถœ๋ ฅ + + for id, passage in id_passage_map.items(): + parts = passage.split("[SEP]") + if len(parts) > 1: + title = parts[0].split("[CLS]")[1].strip() + title_id_map[title].append(id) + else: + logger.debug(f"Unexpected passage format in file {f}, id {id}") + + logger.debug(f"Processed {len(id_passage_map)} passages from {f}...") + + logger.debug(f"Total unique titles: {len(title_id_map)}") + + with open(index_path, "wb") as f: + pickle.dump(title_id_map, f) + + logger.debug(f"Finished saving title_index_mapping at {index_path}!") + + +# if __name__ == "__main__": +# save_orig_passage() +# save_title_index_map() diff --git a/code/rag/data_process/external_data.py b/code/rag/data_process/external_data.py new file mode 100644 index 0000000..e393476 --- /dev/null +++ b/code/rag/data_process/external_data.py @@ -0,0 +1,198 @@ +import json +import os +from pathlib import Path +import re +import urllib.request + +from loguru import logger + + +def preprocess_text(text): + # ํ•œ๊ธ€, ์ˆซ์ž, ํŠน์ˆ˜๋ฌธ์ž, ๊ณต๋ฐฑ๋งŒ ๋‚จ๊ธฐ๊ณ  ๋‚˜๋จธ์ง€ ์ œ๊ฑฐ + text = re.sub(r"\n", " ", text) + text = re.sub(r"\\n", " ", text) + text = re.sub(r"#", " ", text) + text = re.sub(r"\s+", " ", text).strip() + text = re.sub(r'[^ใ„ฑ-ใ…Ž๊ฐ€-ํžฃ0-9!"#%&\'(),-./:;<=>?@[\]^_`{|}~\s]', "", text) + + # ๋‚ด์šฉ์ด ๋นˆ ๊ด„ํ˜ธ ์ œ๊ฑฐ + pattern = r"\(\s*\)" + while re.search(pattern, text): + text = re.sub(pattern, "", text) + + return text + + +def process_json_array(json_data): + # text ํ•„๋“œ ์ „์ฒ˜๋ฆฌ + if "text" in json_data: + json_data["text"] = preprocess_text(json_data["text"]) + + # title ํ•„๋“œ ์ „์ฒ˜๋ฆฌ + if "title" in json_data: + json_data["title"] = preprocess_text(json_data["title"]) + + return json_data + + +def process_json_file(json_filename): + with open(json_filename, "r", encoding="utf-8") as f: + docs = json.load(f) + + processed_docs = [process_json_array(item) for item in docs] + + # ๋””๋ ‰ํ† ๋ฆฌ์™€ ํŒŒ์ผ๋ช… ๋ถ„๋ฆฌ ํ›„ ํŒŒ์ผ๋ช…์—๋งŒ 'processed_' ์ถ”๊ฐ€ + directory, filename = os.path.split(json_filename) + output_path = os.path.join(directory, "processed_" + filename) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(processed_docs, f, ensure_ascii=False, indent=2) + + +def _dump_wiki(data_path: str = "../data"): + """ + ์œ„ํ‚คํ”ผ๋””์•„ ๋คํ”„๋ฅผ ๋‹ค์šด๋กœ๋“œํ•˜๊ณ  ์ถ”์ถœํ•˜๋Š” ํ•จ์ˆ˜ + """ + dump_filename = "kowiki-latest-pages-articles.xml.bz2" + dump_path = os.path.join(data_path, dump_filename) + wiki_url = f"https://dumps.wikimedia.org/kowiki/latest/{dump_filename}" + + # wget https://dumps.wikimedia.org/kowiki/latest/kowiki-latest-pages-articles.xml.bz2 + if not os.path.exists(dump_path): + logger.debug(f"์œ„ํ‚คํ”ผ๋””์•„ ๋คํ”„๋ฅผ ๋‹ค์šด๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค: {wiki_url}") + urllib.request.urlretrieve(wiki_url, dump_path) + logger.debug(f"๋‹ค์šด๋กœ๋“œ ์™„๋ฃŒ: {dump_path}") + + # python -m wikiextractor.WikiExtractor kowiki-latest-pages-articles.xml.bz2 + extract_dir = os.path.join(data_path, "text") + if not os.path.exists(extract_dir): + logger.debug("WikiExtractor๋กœ ๋คํ”„ ํŒŒ์ผ์„ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค...") + os.system(f"python -m wikiextractor.WikiExtractor {dump_path} -o {extract_dir}") + logger.debug("์ถ”์ถœ ์™„๋ฃŒ") + + def _get_filename_list(dirname): + filepaths = [] + for root, dirs, files in os.walk(dirname): + for file in files: + 
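+                # WikiExtractor가 생성한 wiki_00, wiki_01 같은 `wiki_` + 숫자 두 자리 형식의 파일만 수집합니다.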
filepath = os.path.join(root, file) + if re.match(r"wiki_[0-9][0-9]", file): + filepaths.append(filepath) + return sorted(filepaths) + + filepaths = _get_filename_list(extract_dir) + output_path = os.path.join(data_path, "wiki_dump.txt") + + # ํŒŒ์ผ ๋‚ด์šฉ ์ฝ๊ธฐ + all_text = "" + for filepath in filepaths: + with open(filepath, "r", encoding="utf-8") as f: + all_text += f.read() + "\n" + + # ์ „์ฒด ํ…์ŠคํŠธ๋ฅผ ํ•˜๋‚˜์˜ ํŒŒ์ผ๋กœ ์ €์žฅ + with open(output_path, "w", encoding="utf-8") as f: + f.write(all_text) + + logger.debug(f"์ด {len(filepaths)}๊ฐœ์˜ ํŒŒ์ผ์„ ์ฒ˜๋ฆฌํ–ˆ์Šต๋‹ˆ๋‹ค.") + logger.debug(f"๋ชจ๋“  ๋‚ด์šฉ์ด {output_path} ํŒŒ์ผ์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + + +def _parse_wiki_dump(file_path: str = "../data/wiki_dump.txt"): + """ + ์œ„ํ‚คํ”ผ๋””์•„ ๋คํ”„ ํŒŒ์ผ์„ JSON ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜ + """ + documents = [] + current_doc = "" + doc_id = None + title = None + + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + # ์ƒˆ๋กœ์šด ๋ฌธ์„œ ์‹œ์ž‘ + if line.startswith(""): + if current_doc.strip(): # ๋นˆ ๋ฌธ์„œ๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ๋งŒ ์ถ”๊ฐ€ + documents.append({"id": doc_id, "title": title, "text": current_doc.strip()}) + # ๋ฌธ์„œ ๋‚ด์šฉ + else: + current_doc += line + + # JSON ํŒŒ์ผ๋กœ ์ €์žฅ + output_path = file_path.replace(".txt", ".json") + with open(output_path, "w", encoding="utf-8") as json_file: + json.dump(documents, json_file, ensure_ascii=False, indent=4) + + logger.debug(f"JSON ํŒŒ์ผ์ด ์ƒ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค: {output_path}") + logger.debug(f"์ด {len(documents)}๊ฐœ์˜ ๋ฌธ์„œ๊ฐ€ ์ฒ˜๋ฆฌ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + return documents + + +def wikipedia(): + """ + ์œ„ํ‚คํ”ผ๋””์•„ ํ•œ๊ตญ์–ด ๋คํ”„ ๋ฌธ์„œ๋ฅผ ๊ฐ€์ ธ์˜ค๊ณ  ํŒŒ์‹ฑํ•˜์—ฌ ํ•˜๋‚˜์˜ JSON ํŒŒ์ผ ์ƒ์„ฑ + """ + _dump_wiki() + _parse_wiki_dump() + + +def ai_hub_news_corpus(input_dir: str, output_file: str): + """ + ๋Œ€๊ทœ๋ชจ ์›น๋ฐ์ดํ„ฐ ๊ธฐ๋ฐ˜ ํ•œ๊ตญ์–ด ๋ง๋ญ‰์น˜ ๋ฐ์ดํ„ฐ + \n https://www.aihub.or.kr/aihubdata/data/view.do?currMenu=115&topMenu=100&dataSetSn=624 + \n ์ง€์ •๋œ ๋””๋ ‰ํ† ๋ฆฌ์˜ ๋ชจ๋“  JSON ํŒŒ์ผ์„ ์ฒ˜๋ฆฌํ•˜์—ฌ ํ•˜๋‚˜์˜ JSON ํŒŒ์ผ๋กœ ํ†ตํ•ฉ + Args: + input_dir: ์ž…๋ ฅ JSON ํŒŒ์ผ๋“ค์ด ์žˆ๋Š” ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ + output_file: ์ถœ๋ ฅ๋  ํ†ตํ•ฉ JSON ํŒŒ์ผ ๊ฒฝ๋กœ + """ + all_documents = [] + input_path = Path(input_dir) + + try: + # ์ž…๋ ฅ ๋””๋ ‰ํ† ๋ฆฌ ๋‚ด์˜ ๋ชจ๋“  JSON ํŒŒ์ผ ์ฒ˜๋ฆฌ + for json_file in input_path.glob("**/*.json"): + logger.info(f"์ฒ˜๋ฆฌ ์ค‘์ธ ํŒŒ์ผ: {json_file}") + + try: + with open(json_file, "r", encoding="utf-8") as f: + data = json.load(f) + + # SJML ๊ตฌ์กฐ ํ™•์ธ ๋ฐ ๋ฐ์ดํ„ฐ ์ถ”์ถœ + if "SJML" in data and "text" in data["SJML"]: + for doc in data["SJML"]["text"]: + processed_doc = {"title": doc["title"], "text": doc["content"]} + all_documents.append(processed_doc) + else: + logger.warning(f"์ž˜๋ชป๋œ JSON ๊ตฌ์กฐ: {json_file}") + + except json.JSONDecodeError: + logger.error(f"JSON ํŒŒ์‹ฑ ์˜ค๋ฅ˜: {json_file}") + except Exception as e: + logger.error(f"ํŒŒ์ผ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {json_file}, ์˜ค๋ฅ˜: {str(e)}") + + # ์ตœ์ข… ๊ฒฐ๊ณผ๋ฅผ ๋‹จ์ผ JSON ํŒŒ์ผ๋กœ ์ €์žฅ + if all_documents: + with open(output_file, "w", encoding="utf-8") as f: + json.dump(all_documents, f, ensure_ascii=False, indent=2) + logger.info(f"์ฒ˜๋ฆฌ ์™„๋ฃŒ: ์ด {len(all_documents)}๊ฐœ ๋ฌธ์„œ๊ฐ€ {output_file}์— ์ €์žฅ๋จ") + else: + logger.warning("์ฒ˜๋ฆฌ๋œ ๋ฌธ์„œ๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค.") + + except Exception as e: + logger.error(f"์ „์ฒด ์ฒ˜๋ฆฌ ๊ณผ์ • ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}") + + +if __name__ == "__main__": + os.chdir("../../") + + PROCESS_JSON_FILE = False + WIKIPEDIA = False + 
AI_HUB_NEWS_CORPUS = False + + if PROCESS_JSON_FILE: + process_json_file("../data/documents.json") + if WIKIPEDIA: + wikipedia() + if AI_HUB_NEWS_CORPUS: + ai_hub_news_corpus("../data/ai_hub_news_corpus", "../data/ai_hub_news_corpus.json") diff --git a/code/rag/data_process/wiki_dump.py b/code/rag/data_process/wiki_dump.py new file mode 100644 index 0000000..a5ad05e --- /dev/null +++ b/code/rag/data_process/wiki_dump.py @@ -0,0 +1,93 @@ +import json +import os +import re +import urllib.request + +from loguru import logger + + +def dump_wiki(data_path: str = "../data"): + """ + ์œ„ํ‚คํ”ผ๋””์•„ ๋คํ”„๋ฅผ ๋‹ค์šด๋กœ๋“œํ•˜๊ณ  ์ถ”์ถœํ•˜๋Š” ํ•จ์ˆ˜ + """ + dump_filename = "kowiki-latest-pages-articles.xml.bz2" + dump_path = os.path.join(data_path, dump_filename) + wiki_url = f"https://dumps.wikimedia.org/kowiki/latest/{dump_filename}" + + # wget https://dumps.wikimedia.org/kowiki/latest/kowiki-latest-pages-articles.xml.bz2 + if not os.path.exists(dump_path): + logger.debug(f"์œ„ํ‚คํ”ผ๋””์•„ ๋คํ”„๋ฅผ ๋‹ค์šด๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค: {wiki_url}") + urllib.request.urlretrieve(wiki_url, dump_path) + logger.debug(f"๋‹ค์šด๋กœ๋“œ ์™„๋ฃŒ: {dump_path}") + + # python -m wikiextractor.WikiExtractor kowiki-latest-pages-articles.xml.bz2 + extract_dir = os.path.join(data_path, "text") + if not os.path.exists(extract_dir): + logger.debug("WikiExtractor๋กœ ๋คํ”„ ํŒŒ์ผ์„ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค...") + os.system(f"python -m wikiextractor.WikiExtractor {dump_path} -o {extract_dir}") + logger.debug("์ถ”์ถœ ์™„๋ฃŒ") + + def _get_filename_list(dirname): + filepaths = [] + for root, dirs, files in os.walk(dirname): + for file in files: + filepath = os.path.join(root, file) + if re.match(r"wiki_[0-9][0-9]", file): + filepaths.append(filepath) + return sorted(filepaths) + + filepaths = _get_filename_list(extract_dir) + output_path = os.path.join(data_path, "wiki_dump.txt") + + # ํŒŒ์ผ ๋‚ด์šฉ ์ฝ๊ธฐ + all_text = "" + for filepath in filepaths: + with open(filepath, "r", encoding="utf-8") as f: + all_text += f.read() + "\n" + + # ์ „์ฒด ํ…์ŠคํŠธ๋ฅผ ํ•˜๋‚˜์˜ ํŒŒ์ผ๋กœ ์ €์žฅ + with open(output_path, "w", encoding="utf-8") as f: + f.write(all_text) + + logger.debug(f"์ด {len(filepaths)}๊ฐœ์˜ ํŒŒ์ผ์„ ์ฒ˜๋ฆฌํ–ˆ์Šต๋‹ˆ๋‹ค.") + logger.debug(f"๋ชจ๋“  ๋‚ด์šฉ์ด {output_path} ํŒŒ์ผ์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + + +def parse_wiki_dump(file_path: str = "../data/wiki_dump.txt"): + """ + ์œ„ํ‚คํ”ผ๋””์•„ ๋คํ”„ ํŒŒ์ผ์„ JSON ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ํ•จ์ˆ˜ + """ + documents = [] + current_doc = "" + doc_id = None + title = None + + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + # ์ƒˆ๋กœ์šด ๋ฌธ์„œ ์‹œ์ž‘ + if line.startswith(""): + if current_doc.strip(): # ๋นˆ ๋ฌธ์„œ๊ฐ€ ์•„๋‹Œ ๊ฒฝ์šฐ๋งŒ ์ถ”๊ฐ€ + documents.append({"id": doc_id, "title": title, "text": current_doc.strip()}) + # ๋ฌธ์„œ ๋‚ด์šฉ + else: + current_doc += line + + # JSON ํŒŒ์ผ๋กœ ์ €์žฅ + output_path = file_path.replace(".txt", ".json") + with open(output_path, "w", encoding="utf-8") as json_file: + json.dump(documents, json_file, ensure_ascii=False, indent=4) + + logger.debug(f"JSON ํŒŒ์ผ์ด ์ƒ์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค: {output_path}") + logger.debug(f"์ด {len(documents)}๊ฐœ์˜ ๋ฌธ์„œ๊ฐ€ ์ฒ˜๋ฆฌ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + return documents + + +if __name__ == "__main__": + os.chdir("../../") + dump_wiki() + parse_wiki_dump() diff --git a/code/rag/dpr_data.py b/code/rag/dpr_data.py new file mode 100644 index 0000000..b13155f --- /dev/null +++ b/code/rag/dpr_data.py @@ -0,0 +1,226 @@ +# from utils import get_passage_file +from glob import glob +import json +import logging 
+import math +import os +import pickle +import re +import typing +from typing import Iterator, List, Sized, Tuple + +import torch +from torch import tensor as T +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm +from transformers import AutoTokenizer + + +def get_wiki_filepath(data_dir): + return glob(f"{data_dir}/*/wiki_*") + + +def wiki_worker_init(worker_id): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + # logger.debug(dataset) + # dataset = + overall_start = dataset.start + overall_end = dataset.end + split_size = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + worker_id = worker_info.id + # end_idx = min((worker_id+1) * split_size, len(dataset.data)) + dataset.start = overall_start + worker_id * split_size + dataset.end = min(dataset.start + split_size, overall_end) # index error ๋ฐฉ์ง€ + + +def get_passage_file(p_id_list: typing.List[int]) -> str: + """passage id๋ฅผ ๋ฐ›์•„์„œ ํ•ด๋‹น๋˜๋Š” ํŒŒ์ผ ์ด๋ฆ„์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + target_file = None + p_id_max = max(p_id_list) + p_id_min = min(p_id_list) + for f in glob("processed_passages/*.p"): + s, e = f.split("/")[1].split(".")[0].split("-") + s, e = int(s), int(e) + if p_id_min >= s and p_id_max <= e: + target_file = f + return target_file + + +# set logger +os.makedirs("logs", exist_ok=True) +logging.basicConfig( + filename="logs/log.log", + level=logging.DEBUG, + format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s", +) +logger = logging.getLogger() + + +def korquad_collator(batch: List[Tuple], padding_value: int) -> Tuple[torch.Tensor]: + """query, p_id, gold_passage๋ฅผ batch๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + batch_q = pad_sequence([T(e[0]) for e in batch], batch_first=True, padding_value=padding_value) + # logger.debug(batch_q.shape) + batch_q_attn_mask = (batch_q != padding_value).long() + # logger.debug(batch_q_attn_mask.shape) + batch_p_id = T([e[1] for e in batch])[:, None] + # logger.debug(batch_p_id.shape) + batch_p = pad_sequence([T(e[2]) for e in batch], batch_first=True, padding_value=padding_value) + # logger.debug(batch_p.shape) + batch_p_attn_mask = (batch_p != padding_value).long() + return (batch_q, batch_q_attn_mask, batch_p_id, batch_p, batch_p_attn_mask) + + +class KorQuadSampler(torch.utils.data.BatchSampler): + """in-batch negativeํ•™์Šต์„ ์œ„ํ•ด batch ๋‚ด์— ์ค‘๋ณต answer๋ฅผ ๊ฐ–์ง€ ์•Š๋„๋ก batch๋ฅผ ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค. 
+ sample ์ผ๋ถ€๋ฅผ passํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ „์ฒด data ์ˆ˜๋ณด๋‹ค iteration์„ ํ†ตํ•ด ๋‚˜์˜ค๋Š” ๋ฐ์ดํ„ฐ ์ˆ˜๊ฐ€ ๋ช‡์‹ญ๊ฐœ ์ •๋„ ์ ์Šต๋‹ˆ๋‹ค.""" + + def __init__( + self, + data_source: Sized, + batch_size: int, + drop_last: bool = False, + shuffle: bool = True, + generator=None, + ) -> None: + if shuffle: + sampler = torch.utils.data.RandomSampler(data_source, replacement=False, generator=generator) + else: + sampler = torch.utils.data.SequentialSampler(data_source) + super(KorQuadSampler, self).__init__(sampler=sampler, batch_size=batch_size, drop_last=drop_last) + + def __iter__(self) -> Iterator[List[int]]: + sampled_p_id = [] + sampled_idx = [] + for idx in self.sampler: + item = self.sampler.data_source[idx] + if item[1] in sampled_p_id: + continue # ๋งŒ์ผ ๊ฐ™์€ answer passage๊ฐ€ ์ด๋ฏธ ๋ฝ‘ํ˜”๋‹ค๋ฉด pass + sampled_idx.append(idx) + sampled_p_id.append(item[1]) + if len(sampled_idx) >= self.batch_size: + yield sampled_idx + sampled_p_id = [] + sampled_idx = [] + if len(sampled_idx) > 0 and not self.drop_last: + yield sampled_idx + + +class KorQuadDataset: + def __init__(self, korquad_path: str, title_passage_map_path="title_passage_map.p"): + self.korquad_path = korquad_path + self.data_tuples = [] + self.tokenizer = AutoTokenizer.from_pretrained("monologg/kobert", trust_remote_code=True) + self.pad_token_id = self.tokenizer.get_vocab()["[PAD]"] + self.load() + + @property + def dataset(self) -> List[Tuple]: + return self.tokenized_tuples + + def stat(self): + """korquad ๋ฐ์ดํ„ฐ์…‹์˜ ์Šคํƒฏ์„ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค.""" + raise NotImplementedError() + + def load(self): + """๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ๊ฐ€ ์™„๋ฃŒ๋˜์—ˆ๋‹ค๋ฉด loadํ•˜๊ณ  ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด ์ „์ฒ˜๋ฆฌ๋ฅผ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.""" + self.korquad_processed_path = f"{self.korquad_path.split('.json')[0]}_processed.p" + if os.path.exists(self.korquad_processed_path): + logger.debug("preprocessed file already exists, loading...") + with open(self.korquad_processed_path, "rb") as f: + self.tokenized_tuples = pickle.load(f) + logger.debug("successfully loaded tokenized_tuples into self.tokenized_tuples") + + else: + self._load_data() + self._match_passage() + logger.debug("successfully loaded data_tuples into self.data_tuples") + # tokenizing raw dataset + self.tokenized_tuples = [ + (self.tokenizer.encode(q), id, self.tokenizer.encode(p)) + for q, id, p in tqdm(self.data_tuples, desc="tokenize") + ] + self._save_processed_dataset() + logger.debug("finished tokenization") + + def _load_data(self): + with open(self.korquad_path, "rt", encoding="utf8") as f: + data = json.load(f) + self.raw_json = data["data"] + logger.debug("data loaded into self.raw_json") + with open("title_passage_map.p", "rb") as f: + self.title_passage_map = pickle.load(f) + logger.debug("title passage mapping loaded into self.title_passage_map") + + def _get_cand_ids(self, title): + """๋ฏธ๋ฆฌ ๊ตฌ์ถ•ํ•œ ko-wiki ๋ฐ์ดํ„ฐ์—์„œ ํ•ด๋‹น title์— ๋งž๋Š” id๋“ค์„ ๊ฐ€์ง€๊ณ  ์˜ต๋‹ˆ๋‹ค.""" + refined_title = None + ret = self.title_passage_map.get(title, None) + if not ret: + refined_title = re.sub(r"\(.*\)", "", title).strip() + ret = self.title_passage_map.get(refined_title, None) + return ret, refined_title + + def _match_passage(self): + """๋ฏธ๋ฆฌ ๊ตฌ์ถ•ํ•œ ko-wiki ๋ฐ์ดํ„ฐ์™€ korQuad์˜ answer๋ฅผ ๋งค์นญํ•˜์—ฌ + (query, passage_id, passage)์˜ tuple์„ ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.""" + for item in tqdm(self.raw_json, desc="matching silver passages"): + title = item["title"].replace("_", " ") # _๋ฅผ ๊ณต๋ฐฑ๋ฌธ์ž๋กœ ๋ณ€๊ฒฝ + para = item["paragraphs"] + cand_ids, refined_title = 
self._get_cand_ids(title) + if refined_title is not None and cand_ids: + logger.debug(f"refined the title and proceed : {title} -> {refined_title}") + if cand_ids is None: + logger.debug(f"No such title as {title} or {refined_title}. passing this title") + continue + target_file_p = get_passage_file(cand_ids) + if target_file_p is None: + logger.debug(f"No single target file for {title}, got passage ids {cand_ids}. passing this title") + continue + with open(target_file_p, "rb") as f: + target_file = pickle.load(f) + contexts = {cand_id: target_file[cand_id] for cand_id in cand_ids} + + for p in para: + qas = p["qas"] + for qa in qas: + answer = qa["answers"][0]["text"] # ์•„๋ฌด ์ •๋‹ต์ด๋‚˜ ๋ฝ‘์Šต๋‹ˆ๋‹ค. + answer_pos = qa["answers"][0]["answer_start"] + answer_clue_start = max(0, answer_pos - 5) + answer_clue_end = min(len(p["context"]), answer_pos + len(answer) + 5) + answer_clue = p["context"][ + answer_clue_start:answer_clue_end + ] # gold passage๋ฅผ ์ฐพ๊ธฐ ์œ„ํ•ด์„œ +-5์นธ์˜ ์ฃผ๋ณ€ text ํ™œ์šฉ + question = qa["question"] + answer_p = [ + (p_id, c) for p_id, c in contexts.items() if answer_clue in c + ] # answer๊ฐ€ ๋‹จ์ˆœํžˆ ๋“ค์–ด์žˆ๋Š” ๋ฌธ์„œ๋ฅผ ๋ฝ‘๋Š”๋‹ค. + if not answer_p: + answer_p = [(p_id, c) for p_id, c in contexts.items() if answer in c] + + self.data_tuples.extend([(question, p_id, c) for p_id, c in answer_p]) + + def _save_processed_dataset(self): + """์ „์ฒ˜๋ฆฌํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.""" + with open(self.korquad_processed_path, "wb") as f: + pickle.dump(self.tokenized_tuples, f) + logger.debug(f"successfully saved self.tokenized_tuples into {self.korquad_processed_path}") + + +# if __name__ == "__main__": +# ds = KorQuadDataset(korquad_path="./data/KorQuAD_v1.0_train.json") + +# loader = torch.utils.data.DataLoader( +# dataset=ds.dataset, +# batch_sampler=KorQuadSampler(ds.dataset, batch_size=16, drop_last=False), +# collate_fn=lambda x: korquad_collator(x, padding_value=ds.pad_token_id), +# num_workers=4, +# ) +# # logger.debug(len(_dataset.tokenized_tuples)) +# torch.manual_seed(123412341235) +# cnt = 0 +# for batch in tqdm(loader): +# #logger.debug(len(batch)) +# cnt += batch[0].size(0) +# # break +# logger.debug(cnt) diff --git a/code/rag/encoder.py b/code/rag/encoder.py new file mode 100644 index 0000000..8bb1593 --- /dev/null +++ b/code/rag/encoder.py @@ -0,0 +1,62 @@ +from copy import deepcopy +import logging +import os + +import torch +from transformers import BertModel + + +# ๋‘ ๊ฐœ์˜ BertModel์„ ์‚ฌ์šฉํ•˜์—ฌ passage์™€ query๋ฅผ encoding์„ ์‹คํ–‰ +# ํ† ํฌ๋‚˜์ด์ง• ํ›„์— ํ† ํฐ์„ ๊ณ ์ •๋œ ํฌ๊ธฐ์˜ ๋ฒกํ„ฐ๋กœ ๋ณ€๊ฒฝ + +# ๋กœ๊ทธ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ (์—†์œผ๋ฉด ์ƒˆ๋กœ ์ƒ์„ฑ) +os.makedirs("logs", exist_ok=True) + +# ๋กœ๊น… ์„ค์ •: ๋กœ๊ทธ๋ฅผ ํŒŒ์ผ๋กœ ์ €์žฅํ•˜๊ณ  ๋””๋ฒ„๊น… ๋ ˆ๋ฒจ๋กœ ์„ค์ • +logging.basicConfig( + filename="logs/log.log", + level=logging.DEBUG, + format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s", +) +logger = logging.getLogger() + + +# KobertBiEncoder ํด๋ž˜์Šค ์ •์˜ +class KobertBiEncoder(torch.nn.Module): + def __init__(self): + # torch.nn.Module์˜ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜ ํ˜ธ์ถœ + super(KobertBiEncoder, self).__init__() + # passage(๋ฌธ์„œ)๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” BERT ๋ชจ๋ธ + self.passage_encoder = BertModel.from_pretrained("monologg/kobert", trust_remote_code=True) + # query(์งˆ์˜)๋ฅผ ์ฒ˜๋ฆฌํ•˜๋Š” BERT ๋ชจ๋ธ + self.query_encoder = BertModel.from_pretrained("monologg/kobert", trust_remote_code=True) + # BERT ๋ชจ๋ธ์˜ pooler output(์ž„๋ฒ ๋”ฉ ํฌ๊ธฐ) ์„ค์ • + self.emb_sz = self.passage_encoder.pooler.dense.out_features # get cls 
token dim + + def forward(self, x: torch.LongTensor, attn_mask: torch.LongTensor, type: str = "passage") -> torch.FloatTensor: + """passage ๋˜๋Š” query๋ฅผ BERT๋กœ ์ธ์ฝ”๋”ฉํ•ฉ๋‹ˆ๋‹ค.""" + # type์ด 'passage' ๋˜๋Š” 'query'์ธ์ง€ ํ™•์ธ + assert type in ( + "passage", + "query", + ), "type should be either 'passage' or 'query'" + # type์— ๋”ฐ๋ผ ๋‹ค๋ฅธ ์ธ์ฝ”๋” ์‚ฌ์šฉ + if type == "passage": + # ๋ฌธ์„œ(passage) ์ธ์ฝ”๋”ฉ + return self.passage_encoder(input_ids=x, attention_mask=attn_mask).pooler_output + else: + # ์งˆ์˜(query) ์ธ์ฝ”๋”ฉ + return self.query_encoder(input_ids=x, attention_mask=attn_mask).pooler_output + + def checkpoint(self, model_ckpt_path): + # ๋ชจ๋ธ์˜ ๊ฐ€์ค‘์น˜๋ฅผ ํŒŒ์ผ๋กœ ์ €์žฅ + torch.save(deepcopy(self.state_dict()), model_ckpt_path) + logger.debug(f"model self.state_dict saved to {model_ckpt_path}") + + def load(self, model_ckpt_path): + # ์ €์žฅ๋œ ๊ฐ€์ค‘์น˜๋ฅผ ํŒŒ์ผ์—์„œ ๋กœ๋“œ + with open(model_ckpt_path, "rb") as f: + state_dict = torch.load(f) + # ๋ชจ๋ธ์— ๋กœ๋“œ๋œ ๊ฐ€์ค‘์น˜ ์ ์šฉ + self.load_state_dict(state_dict) + logger.debug(f"model self.state_dict loaded from {model_ckpt_path}") diff --git a/code/rag/index_runner.py b/code/rag/index_runner.py new file mode 100644 index 0000000..6a28f73 --- /dev/null +++ b/code/rag/index_runner.py @@ -0,0 +1,188 @@ +from glob import glob +import logging +import math +import os +import typing +from typing import List, Tuple + +from chunk_data import DataChunk +from encoder import KobertBiEncoder +import indexers +import torch +from torch import tensor as T +from torch.nn.utils.rnn import pad_sequence +from tqdm import tqdm +import transformers + + +# from utils import get_wiki_filepath, wiki_worker_init +transformers.logging.set_verbosity_error() # ํ† ํฌ๋‚˜์ด์ € ์ดˆ๊ธฐํ™” ๊ด€๋ จ warning suppress + + +def get_wiki_filepath(data_dir): + return glob(f"{data_dir}/*/wiki_*") + + +def wiki_worker_init(worker_id): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + # print(dataset) + # dataset = + overall_start = dataset.start + overall_end = dataset.end + split_size = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + worker_id = worker_info.id + # end_idx = min((worker_id+1) * split_size, len(dataset.data)) + dataset.start = overall_start + worker_id * split_size + dataset.end = min(dataset.start + split_size, overall_end) # index error ๋ฐฉ์ง€ + + +def get_passage_file(p_id_list: typing.List[int]) -> str: + """passage id๋ฅผ ๋ฐ›์•„์„œ ํ•ด๋‹น๋˜๋Š” ํŒŒ์ผ ์ด๋ฆ„์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + target_file = None + p_id_max = max(p_id_list) + p_id_min = min(p_id_list) + for f in glob("processed_passages/*.p"): + s, e = f.split("/")[1].split(".")[0].split("-") + s, e = int(s), int(e) + if p_id_min >= s and p_id_max <= e: + target_file = f + return target_file + + +# logger basic config +os.makedirs("logs", exist_ok=True) +logging.basicConfig( + filename="logs/log.log", + level=logging.DEBUG, + format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s", +) +logger = logging.getLogger() + + +def wiki_collator(batch: List, padding_value: int) -> Tuple[torch.Tensor]: + """passage๋ฅผ batch๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + batch_p = pad_sequence([T(e) for e in batch], batch_first=True, padding_value=padding_value) + batch_p_attn_mask = (batch_p != padding_value).long() + return (batch_p, batch_p_attn_mask) + + +class WikiArticleStream(torch.utils.data.IterableDataset): + """ + Indexing์„ ์œ„ํ•ด random access๊ฐ€ ํ•„์š”ํ•˜์ง€ ์•Š๊ณ  large corpus๋ฅผ ๋‹ค๋ฃจ๊ธฐ ์œ„ํ•ด stream 
dataset์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. + """ + + def __init__(self, wiki_path, chunker): + # self.chunk_size = chunk_size + super(WikiArticleStream, self).__init__() + self.chunker = chunker + self.pad_token_id = self.chunker.tokenizer.get_vocab()["[PAD]"] + self.wiki_path = wiki_path + self.max_length = 168 # maximum length for kowiki passage + # self.start = 0 + # self.end = len(self.passages) + + def __iter__(self): + # max_length๊ฐ€ ๋˜๋„๋ก padding ์ˆ˜ํ–‰ + + _, passages = self.chunker.chunk(self.wiki_path) + logger.debug(f"chunked file {self.wiki_path}") + for passage in passages: + # if len(passage) > self.max_length: + # continue # ์ง€์ •๋œ max_length๋ณด๋‹ค ๊ธด passage์˜ ๊ฒฝ์šฐ pass + # padded_passage = T( + # passage + # + [self.pad_token_id for _ in range(self.max_length - len(passage))] + # ) + yield passage + + +class IndexRunner: + """์ฝ”ํผ์Šค์— ๋Œ€ํ•œ ์ธ๋ฑ์‹ฑ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฉ”์ธํด๋ž˜์Šค์ž…๋‹ˆ๋‹ค. + passage encoder์™€ data loader, FAISS indexer๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.""" + + def __init__( + self, + data_dir: str, + model_ckpt_path: str, + indexer_type: str = "DenseFlatIndexer", + chunk_size: int = 100, + batch_size: int = 64, + buffer_size: int = 50000, + index_output: str = "", + device: str = "cuda:0", + ): + """ + data_dir : ์ธ๋ฑ์‹ฑํ•  ํ•œ๊ตญ์–ด wiki ๋ฐ์ดํ„ฐ๊ฐ€ ๋“ค์–ด์žˆ๋Š” ๋””๋ ‰ํ† ๋ฆฌ์ž…๋‹ˆ๋‹ค. ํ•˜์œ„์— AA, AB์™€ ๊ฐ™์€ ๋””๋ ‰ํ† ๋ฆฌ๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค. + indexer_type : ์‚ฌ์šฉํ•  FAISS indexer ์ข…๋ฅ˜๋กœ + DPR ๋ฆฌํฌ์— ์žˆ๋Š” ๋Œ€๋กœ Flat, HNSWFlat, HNSWSQ ์„ธ ์ข…๋ฅ˜ ์ค‘์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. + chunk_size : indexing๊ณผ searching์˜ ๋‹จ์œ„๊ฐ€ ๋˜๋Š” passage์˜ ๊ธธ์ด์ž…๋‹ˆ๋‹ค. + DPR ๋…ผ๋ฌธ์—์„œ๋Š” 100๊ฐœ token ๊ธธ์ด + title๋กœ ํ•˜๋‚˜์˜ passage๋ฅผ ์ •์˜ํ•˜์˜€์Šต๋‹ˆ๋‹ค. + """ + if "=" in data_dir: + self.data_dir, self.to_this_page = data_dir.split("=") + self.to_this_page = int(self.to_this_page) + self.wiki_files = get_wiki_filepath(self.data_dir) + else: + self.data_dir = data_dir + self.wiki_files = get_wiki_filepath(self.data_dir) + self.to_this_page = len(self.wiki_files) + + self.device = torch.device(device) + self.encoder = KobertBiEncoder().to(self.device) + self.encoder.load(model_ckpt_path) # loading model + self.indexer = getattr(indexers, indexer_type)() + self.chunk_size = chunk_size + self.batch_size = batch_size + self.buffer_size = buffer_size + self.loader = self.get_loader( + self.wiki_files[: self.to_this_page], + chunk_size, + batch_size, + worker_init_fn=None, + ) + self.indexer.init_index(self.encoder.emb_sz) + self.index_output = index_output if index_output else indexer_type + + @staticmethod + def get_loader(wiki_files, chunk_size, batch_size, worker_init_fn=None): + chunker = DataChunk(chunk_size=chunk_size) + ds = torch.utils.data.ChainDataset(tuple(WikiArticleStream(path, chunker) for path in wiki_files)) + loader = torch.utils.data.DataLoader( + ds, + batch_size=batch_size, + collate_fn=lambda x: wiki_collator(x, padding_value=chunker.tokenizer.get_vocab()["[PAD]"]), + num_workers=1, + worker_init_fn=worker_init_fn, + ) # TODO : chain dataset์—์„œ worker 1 ์ดˆ๊ณผ์ธ ๊ฒฝ์šฐ ํ™•์ธํ•˜๊ธฐ + return loader + + def run(self): + _to_index = [] + cur = 0 + for batch in tqdm(self.loader, desc="indexing"): + p, p_mask = batch + p, p_mask = p.to(self.device), p_mask.to(self.device) + with torch.no_grad(): + p_emb = self.encoder(p, p_mask, "passage") + _to_index += [(cur + i, _emb) for i, _emb in enumerate(p_emb.cpu().numpy())] + cur += p_emb.size(0) + if len(_to_index) > self.buffer_size - self.batch_size: + logger.debug(f"perform 
indexing... {len(_to_index)} added") + self.indexer.index_data(_to_index) + _to_index = [] + if _to_index: + logger.debug(f"perform indexing... {len(_to_index)} added") + self.indexer.index_data(_to_index) + _to_index = [] + os.makedirs(self.index_output, exist_ok=True) + self.indexer.serialize(self.index_output) + + +# if __name__ == "__main__": +# IndexRunner( +# data_dir="./dataset/text", +# model_ckpt_path="./my_model.pt", +# index_output="2050iter_flat", +# ).run() +# # test_loader() diff --git a/code/rag/indexers.py b/code/rag/indexers.py new file mode 100644 index 0000000..e503a28 --- /dev/null +++ b/code/rag/indexers.py @@ -0,0 +1,233 @@ +# Credit : facebookresearch/DPR + +""" +FAISS-based index components for dense retriever +""" + +import logging +import os +import pickle +from typing import List, Tuple + +import faiss +import numpy as np + + +logger = logging.getLogger() + + +class DenseIndexer(object): + def __init__(self, buffer_size: int = 50000): + """ + ์ธ๋ฑ์„œ๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค. + - buffer_size: ํ•œ ๋ฒˆ์— ์ฒ˜๋ฆฌํ•  ๋ฒกํ„ฐ์˜ ์ตœ๋Œ€ ๊ฐœ์ˆ˜ (๊ธฐ๋ณธ๊ฐ’์€ 50000). + - index_id_to_db_id: FAISS์—์„œ ์ธ๋ฑ์Šค ID์™€ ์‹ค์ œ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ID๋ฅผ ๋งคํ•‘ํ•˜๊ธฐ ์œ„ํ•œ ๋ฆฌ์ŠคํŠธ. + - index: FAISS ์ธ๋ฑ์Šค ๊ฐ์ฒด. + """ + self.buffer_size = buffer_size + self.index_id_to_db_id = [] # ์ธ๋ฑ์Šค ID๋ฅผ ์‹ค์ œ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค ID์™€ ๋งคํ•‘ํ•˜๋Š” ๋ฆฌ์ŠคํŠธ. + self.index = None # FAISS ์ธ๋ฑ์Šค ๊ฐ์ฒด. + + def init_index(self, vector_sz: int): + raise NotImplementedError + + def index_data(self, data: List[Tuple[object, np.array]]): + raise NotImplementedError + + def get_index_name(self): + raise NotImplementedError + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + raise NotImplementedError + + def serialize(self, file: str): + logger.info("Serializing index to %s", file) + + if os.path.isdir(file): + index_file = os.path.join(file, "index.dpr") + meta_file = os.path.join(file, "index_meta.dpr") + else: + index_file = file + ".index.dpr" + meta_file = file + ".index_meta.dpr" + + faiss.write_index(self.index, index_file) # FAISS ์ธ๋ฑ์Šค๋ฅผ ํŒŒ์ผ์— ์ €์žฅ. + with open(meta_file, mode="wb") as f: + pickle.dump(self.index_id_to_db_id, f) # ID ๋งคํ•‘ ์ •๋ณด๋ฅผ ์ €์žฅ. + + def get_files(self, path: str): + if os.path.isdir(path): + index_file = os.path.join(path, "index.dpr") # FAISS ์ธ๋ฑ์Šค๋ฅผ ํŒŒ์ผ์—์„œ ๋กœ๋“œ. 
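The comments in `DenseIndexer` above explain why `index_id_to_db_id` is kept alongside the FAISS index and serialized with it: FAISS only returns sequential internal positions, so an external-id mapping is needed to get back to passage ids. A small self-contained sketch of that pattern with toy vectors and hypothetical passage ids:

```python
# Toy sketch (not from this repo) of the id-mapping pattern DenseFlatIndexer implements.
import faiss
import numpy as np

emb_sz = 4
index = faiss.IndexFlatIP(emb_sz)               # inner-product flat index, as in DenseFlatIndexer
index_id_to_db_id = []

passage_ids = [101, 205, 999]                   # hypothetical external (db) passage ids
vectors = np.random.rand(3, emb_sz).astype("float32")
index.add(vectors)                              # FAISS assigns internal ids 0, 1, 2
index_id_to_db_id.extend(passage_ids)           # keep the parallel mapping

query = np.random.rand(1, emb_sz).astype("float32")
scores, internal_ids = index.search(query, 2)
db_ids = [index_id_to_db_id[i] for i in internal_ids[0]]  # convert back to passage ids
print(db_ids, scores[0])
```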
+ meta_file = os.path.join(path, "index_meta.dpr") + else: + index_file = path + ".{}.dpr".format(self.get_index_name()) + meta_file = path + ".{}_meta.dpr".format(self.get_index_name()) + return index_file, meta_file + + def index_exists(self, path: str): + index_file, meta_file = self.get_files(path) + return os.path.isfile(index_file) and os.path.isfile(meta_file) + + def deserialize(self, path: str): + logger.info("Loading index from %s", path) + index_file, meta_file = self.get_files(path) + + self.index = faiss.read_index(index_file) + logger.info("Loaded index of type %s and size %d", type(self.index), self.index.ntotal) + + with open(meta_file, "rb") as reader: + self.index_id_to_db_id = pickle.load(reader) + assert ( + len(self.index_id_to_db_id) == self.index.ntotal + ), "Deserialized index_id_to_db_id should match faiss index size" + + def _update_id_mapping(self, db_ids: List) -> int: + self.index_id_to_db_id.extend(db_ids) + return len(self.index_id_to_db_id) + + +class DenseFlatIndexer(DenseIndexer): + def __init__(self, buffer_size: int = 50000): + super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) + + def init_index(self, vector_sz: int): + self.index = faiss.IndexFlatIP(vector_sz) # Inner Product๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ธฐ๋ณธ ์ธ๋ฑ์Šค ์ดˆ๊ธฐํ™”. + + def index_data(self, data: List[Tuple[object, np.array]]): + n = len(data) + # indexing in batches is beneficial for many faiss index types + for i in range(0, n, self.buffer_size): + db_ids = [t[0] for t in data[i : i + self.buffer_size]] + vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size]] + vectors = np.concatenate(vectors, axis=0) + total_data = self._update_id_mapping(db_ids) + self.index.add(vectors) + logger.info("data indexed %d", total_data) + + indexed_cnt = len(self.index_id_to_db_id) + logger.info("Total data indexed %d", indexed_cnt) + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + scores, indexes = self.index.search(query_vectors, top_docs) # ์ฟผ๋ฆฌ ๋ฒกํ„ฐ์™€ ๊ฐ€์žฅ ์œ ์‚ฌํ•œ ๋ฒกํ„ฐ ๊ฒ€์ƒ‰. + # convert to external ids + db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] + result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] + return result + + def get_index_name(self): + return "flat_index" + + +class DenseHNSWFlatIndexer(DenseIndexer): + """ + Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage + """ + + def __init__( + self, + buffer_size: int = 1e9, + store_n: int = 512, + ef_search: int = 128, + ef_construction: int = 200, + ): + super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) + self.store_n = store_n + self.ef_search = ef_search + self.ef_construction = ef_construction + self.phi = 0 + + def init_index(self, vector_sz: int): + # IndexHNSWFlat supports L2 similarity only + # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension + index = faiss.IndexHNSWFlat(vector_sz + 1, self.store_n) # L2 ๊ฑฐ๋ฆฌ ๊ธฐ๋ฐ˜ HNSW ์ธ๋ฑ์Šค ์ดˆ๊ธฐํ™”. + index.hnsw.efSearch = self.ef_search + index.hnsw.efConstruction = self.ef_construction + self.index = index + + def index_data(self, data: List[Tuple[object, np.array]]): + n = len(data) + + # max norm is required before putting all vectors in the index to convert inner product similarity to L2 + if self.phi > 0: + raise RuntimeError( + "DPR HNSWF index needs to index all data at once," "results will be unpredictable otherwise." 
+ ) + phi = 0 + for i, item in enumerate(data): + id, doc_vector = item[0:2] + norms = (doc_vector**2).sum() + phi = max(phi, norms) + logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi)) + self.phi = phi + + # indexing in batches is beneficial for many faiss index types + bs = int(self.buffer_size) + for i in range(0, n, bs): + db_ids = [t[0] for t in data[i : i + bs]] + vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + bs]] + + norms = [(doc_vector**2).sum() for doc_vector in vectors] + aux_dims = [np.sqrt(phi - norm) for norm in norms] + hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in enumerate(vectors)] + hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) + self.train(hnsw_vectors) + + self._update_id_mapping(db_ids) + self.index.add(hnsw_vectors) + logger.info("data indexed %d", len(self.index_id_to_db_id)) + indexed_cnt = len(self.index_id_to_db_id) + logger.info("Total data indexed %d", indexed_cnt) + + def train(self, vectors: np.array): + pass + + def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: + aux_dim = np.zeros(len(query_vectors), dtype="float32") + query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) + logger.info("query_hnsw_vectors %s", query_nhsw_vectors.shape) + scores, indexes = self.index.search(query_nhsw_vectors, top_docs) + # convert to external ids + db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] + result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] + return result + + def deserialize(self, file: str): + super(DenseHNSWFlatIndexer, self).deserialize(file) + # to trigger exception on subsequent indexing + self.phi = 1 + + def get_index_name(self): + return "hnsw_index" + + +class DenseHNSWSQIndexer(DenseHNSWFlatIndexer): + """ + Efficient index for retrieval. 
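The `index_data` method above applies the DPR trick for running inner-product search on an L2-only HNSW index: with phi set to the maximum squared document norm, each document d is augmented to [d, sqrt(phi - ||d||^2)] and each query q to [q, 0], so that ||[q,0] - [d,aux]||^2 = ||q||^2 + phi - 2(q . d). Since ||q||^2 and phi are constant per query, ranking by L2 distance over the augmented vectors matches ranking by inner product over the originals. A numeric sketch with toy vectors (not from this repo):

```python
# Verifies the DOT -> L2 conversion used by DenseHNSWFlatIndexer on toy data.
import numpy as np

docs = np.array([[0.2, 0.9], [0.8, 0.1], [0.5, 0.5]], dtype="float32")
query = np.array([0.7, 0.3], dtype="float32")

norms = (docs ** 2).sum(axis=1)
phi = norms.max()
aux = np.sqrt(phi - norms)                      # auxiliary dimension per document
docs_aug = np.hstack([docs, aux[:, None]])
query_aug = np.hstack([query, [0.0]])

ip_rank = np.argsort(-docs @ query)             # rank by inner product (descending)
l2_rank = np.argsort(((docs_aug - query_aug) ** 2).sum(axis=1))  # rank by L2 (ascending)
assert (ip_rank == l2_rank).all()               # both orderings agree
```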
Note: default settings are for hugh accuracy but also high RAM usage + """ + + def __init__( + self, + buffer_size: int = 1e10, + store_n: int = 128, + ef_search: int = 128, + ef_construction: int = 200, + ): + super(DenseHNSWSQIndexer, self).__init__( + buffer_size=buffer_size, + store_n=store_n, + ef_search=ef_search, + ef_construction=ef_construction, + ) + + def init_index(self, vector_sz: int): + # IndexHNSWFlat supports L2 similarity only + # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension + index = faiss.IndexHNSWSQ(vector_sz + 1, faiss.ScalarQuantizer.QT_8bit, self.store_n) + index.hnsw.efSearch = self.ef_search + index.hnsw.efConstruction = self.ef_construction + self.index = index + + def train(self, vectors: np.array): + self.index.train(vectors) + + def get_index_name(self): + return "hnswsq_index" diff --git a/code/rag/prepare_dense.py b/code/rag/prepare_dense.py new file mode 100644 index 0000000..3f81dfe --- /dev/null +++ b/code/rag/prepare_dense.py @@ -0,0 +1,135 @@ +import logging +import os + +from chunk_data import save_orig_passage, save_title_index_map +from dpr_data import KorQuadDataset +from encoder import KobertBiEncoder + +# ์™ธ๋ถ€ ์Šคํฌ๋ฆฝํŠธ์—์„œ IndexRunner ์ž„ํฌํŠธ +from index_runner import IndexRunner +from indexers import DenseFlatIndexer # index ๊ด€๋ จ +from retriever import KorDPRRetriever # retriever.py์—์„œ ๊ฐ€์ ธ์˜ค๊ธฐ +import torch +from trainer import Trainer +import transformers + + +transformers.logging.set_verbosity_error() # ํ† ํฌ๋‚˜์ด์ € ์ดˆ๊ธฐํ™” ๊ด€๋ จ ๊ฒฝ๊ณ  ์–ต์ œ + +# ๋กœ๊น… ์„ค์ • +os.makedirs("logs", exist_ok=True) +logging.basicConfig( + filename="logs/log.log", + level=logging.DEBUG, + format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s", +) +logger = logging.getLogger() + + +# ๋ชจ๋ธ ์กด์žฌ ์—ฌ๋ถ€ ํ™•์ธ ํ•จ์ˆ˜ +def check_if_model_exists(model_path: str): + """๋ชจ๋ธ ์ฒดํฌํฌ์ธํŠธ๊ฐ€ ์กด์žฌํ•˜๋Š”์ง€ ํ™•์ธํ•˜๋Š” ํ•จ์ˆ˜""" + return os.path.exists(model_path) + + +# ์œ„ํ‚ค ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ ํ•จ์ˆ˜ +def process_wiki_data(): + """ + ์œ„ํ‚ค ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๊ณ  ํ•„์š”ํ•œ ํ”ผํด ํŒŒ์ผ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. + """ + processed_passages_path = "processed_passages" + + if not os.path.exists(processed_passages_path): + logger.debug(f"'{processed_passages_path}' ํด๋”๊ฐ€ ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ๋ฅผ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.") + + # 1. chunk ๋ฐ์ดํ„ฐ๋ฅผ ํ”ผํด ํŒŒ์ผ๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ processed_passages ํด๋”์— ์ €์žฅ (์•ฝ 10๋ถ„ ์†Œ์š”) + save_orig_passage() + + # 2. ์ œ๋ชฉ๊ณผ ์ธ๋ฑ์Šค ๋งคํ•‘ ์ €์žฅ + save_title_index_map() + + logger.debug("๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ๊ฐ€ ์™„๋ฃŒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + else: + logger.debug(f"'{processed_passages_path}' ํด๋”๊ฐ€ ์ด๋ฏธ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ๋ฅผ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค.") + + +if __name__ == "__main__": + # ์œ„ํ‚ค ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ + # processed_passage ํด๋” ๋‚ด์— ํ”ผํดํ™”๋œ ๋ฐ์ดํ„ฐ๊ฐ€ ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. 10๋ถ„ ์†Œ์š” + process_wiki_data() + + # ๋ชจ๋ธ ๊ฒฝ๋กœ ์„ค์ • + model_path = "./output/my_model.pt" + + # korquad ๋ฐ์ดํ„ฐ๋กœ ๋ชจ๋ธ์„ ํ•™์Šต์‹œ์ผœ์ค๋‹ˆ๋‹ค. + # ๋ชจ๋ธ์ด ์ด๋ฏธ ์กด์žฌํ•˜๋ฉด ํ•™์Šต์„ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค + if check_if_model_exists(model_path): + logger.debug(f"์ด๋ฏธ ํ•™์Šต๋œ ๋ชจ๋ธ์ด {model_path}์— ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. ํ•™์Šต์„ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค.") + else: + logger.debug("ํ•™์Šต๋œ ๋ชจ๋ธ์ด ์—†์Šต๋‹ˆ๋‹ค. 
ํ•™์Šต์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.") + + # ๋ชจ๋ธ๊ณผ ๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„ + device = torch.device("cuda:0") + model = KobertBiEncoder() + train_dataset = KorQuadDataset("./data/KorQuAD_v1.0_train.json") + valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json") + + # Trainer ๊ฐ์ฒด ์ƒ์„ฑ ๋ฐ ํ•™์Šต ์‹œ์ž‘ + my_trainer = Trainer( + model=model, + device=device, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + num_epoch=10, + batch_size=128 - 32, + lr=1e-5, + betas=(0.9, 0.99), + num_warmup_steps=1000, + num_training_steps=100000, + valid_every=30, + best_val_ckpt_path=model_path, + ) + + # ํ•™์Šต ์ƒํƒœ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ + # my_trainer.load_training_state() + + # ํ•™์Šต ์‹œ์ž‘ + my_trainer.fit() + + # Indexing์„ ์‹คํ–‰ํ•˜๋Š” ์ฝ”๋“œ (IndexRunner ์‚ฌ์šฉ) + index_path = "./2050iter_flat" # ์ธ๋ฑ์Šค ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ • + + # ์ธ๋ฑ์Šค๊ฐ€ ์ด๋ฏธ ์กด์žฌํ•˜๋ฉด ์ธ๋ฑ์‹ฑ์„ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค + if not os.path.exists(index_path): + logger.info("์ธ๋ฑ์Šค๊ฐ€ ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ์ธ๋ฑ์‹ฑ์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.") + index_runner = IndexRunner( + data_dir="./text", + model_ckpt_path="./output/my_model.pt", + index_output=index_path, + ) + index_runner.run() + else: + logger.info(f"์ธ๋ฑ์Šค ํŒŒ์ผ '{index_path}'๊ฐ€ ์ด๋ฏธ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. ์ธ๋ฑ์‹ฑ์„ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค.") + + # index ํŒŒ์ผ ๋กœ๋”ฉ + index = DenseFlatIndexer() + index.deserialize(path=index_path) # ์ด๋ฏธ ์ƒ์„ฑ๋œ ์ธ๋ฑ์Šค ํŒŒ์ผ์„ ๋กœ๋“œ + + # retriever.py๋กœ๋ถ€ํ„ฐ KorDPRRetriever ๊ฐ์ฒด๋ฅผ ์ƒ์„ฑํ•˜์—ฌ ์ฟผ๋ฆฌ ์‹คํ–‰ + model = KobertBiEncoder() + model.load("./output/my_model.pt") + model.eval() + + valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json") + retriever = KorDPRRetriever(model=model, valid_dataset=valid_dataset, index=index) + + # 'query'์™€ 'k' ๊ฐ’์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค. + query = "์ค‘๊ตญ์˜ ์ฒœ์•ˆ๋ฌธ ์‚ฌํƒœ๊ฐ€ ์ผ์–ด๋‚œ ๋…„๋„๋Š”?" + k = 10 # ์ƒ์œ„ 10๊ฐœ ์œ ์‚ฌํ•œ passage๋ฅผ ์ถœ๋ ฅ + + # retrieve ๋ฉ”์„œ๋“œ๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ๊ฐ€์žฅ ์œ ์‚ฌ๋„๊ฐ€ ๋†’์€ k๊ฐœ์˜ passage๋ฅผ ์ฐพ์Šต๋‹ˆ๋‹ค. + passages = retriever.retrieve(query=query, k=k) + + # ์ถœ๋ ฅ: ์œ ์‚ฌ๋„ ๋†’์€ passage์™€ ๊ทธ ์œ ์‚ฌ๋„๋ฅผ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค. 
+ for idx, (passage, sim) in enumerate(passages): + logger.debug(f"Rank {idx + 1} | Similarity: {sim:.4f} | Passage: {passage}") diff --git a/code/rag/reranker.py b/code/rag/reranker.py new file mode 100644 index 0000000..2716a23 --- /dev/null +++ b/code/rag/reranker.py @@ -0,0 +1,105 @@ +import gc +import os +from typing import Dict, List + +from dotenv import load_dotenv +from loguru import logger +import numpy as np +import torch +from tqdm import tqdm +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from .retriever_elastic import ElasticsearchRetriever + + +class Reranker: + def __init__( + self, + model_path: str = "Dongjin-kr/ko-reranker", + batch_size: int = 128, + max_length: int = 512, + ): + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = AutoModelForSequenceClassification.from_pretrained(model_path) + self.model.to(self.device) + self.model.eval() + self.tokenizer = AutoTokenizer.from_pretrained(model_path) + self.batch_size = batch_size + self.max_length = max_length + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # GPU ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ + del self.model + del self.tokenizer + torch.cuda.empty_cache() + gc.collect() + + def _exp_normalize(self, x): + y = np.exp(x - x.max(axis=1, keepdims=True)) + return y / y.sum(axis=1, keepdims=True) + + def rerank(self, queries: List[str], retrieve_results: List[List[Dict]], topk: int = 5) -> List[List[Dict]]: + # ์ž…๋ ฅ ๋ฐ์ดํ„ฐ ์ค€๋น„ + all_pairs = [] + for query, results in zip(queries, retrieve_results): + for result in results: + all_pairs.append([query, result["text"]]) + + # ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ + all_scores = [] + for i in tqdm(range(0, len(all_pairs), self.batch_size), desc="Reranking"): + batch_pairs = all_pairs[i : i + self.batch_size] + + with torch.no_grad(): + inputs = self.tokenizer( + batch_pairs, + padding=True, + truncation=True, + return_tensors="pt", + max_length=self.max_length, + ) + inputs = {k: v.to(self.device) for k, v in inputs.items()} + batch_scores = self.model(**inputs, return_dict=True).logits.view(-1).float().cpu().numpy() + all_scores.extend(batch_scores) + + all_scores = np.array(all_scores) + + reranked_results = [] + start = 0 + for results in retrieve_results: + end = start + len(results) + scores = all_scores[start:end] + scores = self._exp_normalize(scores.reshape(1, -1)).flatten() + top_indices = np.argsort(scores)[-topk:][::-1] + + reranked_batch = [{"text": results[i]["text"], "score": float(scores[i])} for i in top_indices] + reranked_results.append(reranked_batch) + start = end + + return reranked_results + + +if __name__ == "__main__": + config_folder = os.path.join(os.path.dirname(__file__), "..", "..", "config") + load_dotenv(os.path.join(config_folder, ".env")) + + reranker = Reranker( + model_path="Dongjin-kr/ko-reranker", + batch_size=128, + max_length=512, + ) + retriever = ElasticsearchRetriever( + index_name="two-wiki-index", + ) + + query = "์„ ๋น„๋“ค ์ˆ˜๋งŒ ๋ช…์ด ๋Œ€๊ถ ์•ž์— ๋ชจ์—ฌ ๋งŒ ๋™๋ฌ˜์™€ ์„œ์›์„ ๋‹ค์‹œ ์„ค๋ฆฝํ•  ๊ฒƒ์„ ์ฒญํ•˜๋‹ˆ, (๊ฐ€)์ด/๊ฐ€ ํฌ๊ฒŒ ๋…ธํ•˜์—ฌ ํ•œ์„ฑ๋ถ€์˜ ์กฐ๋ก€(็š‚้šท)์™€ ๋ณ‘์กธ๋กœ ํ•˜์—ฌ ๊ธˆ ํ•œ ๊ฐ• ๋ฐ–์œผ๋กœ ๋ชฐ์•„๋‚ด๊ฒŒ ํ•˜๊ณ  ๋“œ๋””์–ด ์ฒœ์—ฌ ๊ณณ์˜ ์„œ์›์„ ์ฒ ํํ•˜๊ณ  ๊ทธ ํ† ์ง€๋ฅผ ๋ชฐ์ˆ˜ํ•˜์—ฌ ๊ด€์— ์†ํ•˜๊ฒŒ ํ•˜์˜€๋‹ค.๏ผ๋Œ€ํ•œ๊ณ„๋…„์‚ฌ" # noqa: E501 + retriever_result = retriever.retrieve(query, top_k=5) + logger.debug("Elasticsearch Retriever") + logger.debug(f"{retriever_result[:5]}") + + reranked_results = 
reranker.rerank(queries=[query], retrieve_results=[retriever_result], topk=3) + logger.debug("Reranker") + logger.debug(f"{reranked_results}") diff --git a/code/rag/retriever.py b/code/rag/retriever.py new file mode 100644 index 0000000..b32d5b6 --- /dev/null +++ b/code/rag/retriever.py @@ -0,0 +1,187 @@ +from collections import defaultdict +from glob import glob +import math +import os +import pickle +import typing + +# from utils import get_passage_file +from typing import List + +from dpr_data import KorQuadDataset, KorQuadSampler, korquad_collator +from encoder import KobertBiEncoder +from indexers import DenseFlatIndexer +from loguru import logger +import torch +from torch import tensor as T +from tqdm import tqdm + + +def get_wiki_filepath(data_dir): + return glob(f"{data_dir}/*/wiki_*") + + +def wiki_worker_init(worker_id): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + # logger.debug(dataset) + # dataset = + overall_start = dataset.start + overall_end = dataset.end + split_size = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + worker_id = worker_info.id + # end_idx = min((worker_id+1) * split_size, len(dataset.data)) + dataset.start = overall_start + worker_id * split_size + dataset.end = min(dataset.start + split_size, overall_end) # index error ๋ฐฉ์ง€ + + +def get_passage_file(p_id_list: typing.List[int]) -> str: + """passage id๋ฅผ ๋ฐ›์•„์„œ ํ•ด๋‹น๋˜๋Š” ํŒŒ์ผ ์ด๋ฆ„์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + target_file = None + p_id_max = max(p_id_list) + p_id_min = min(p_id_list) + + # ํ˜„์žฌ ํŒŒ์ผ์˜ ๊ฒฝ๋กœ๋ฅผ ๊ธฐ์ค€์œผ๋กœ 'processed_passages' ๊ฒฝ๋กœ ์„ค์ • + current_dir = os.path.dirname(os.path.abspath(__file__)) + passages_dir = os.path.join(current_dir, "processed_passages") + + # 'processed_passages' ๋””๋ ‰ํ„ฐ๋ฆฌ์—์„œ ํŒŒ์ผ์„ ์ฐพ์Œ + for f in glob(f"{passages_dir}/*.p"): + file_name = os.path.basename(f) + s, e = file_name.split(".")[0].split("-") + s, e = int(s), int(e) + + if p_id_min >= s and p_id_max <= e: + target_file = f + break + + if target_file is None: + logger.debug(f"No file found for passage IDs: {p_id_list}") + + return target_file + + +class KorDPRRetriever: + def __init__(self, model, valid_dataset, index, val_batch_size: int = 64, device="cuda:0"): + # ๋ชจ๋ธ์ด ๊ฒฝ๋กœ๋กœ ์ฃผ์–ด์ง„ ๊ฒฝ์šฐ ๋กœ๋“œ + if isinstance(model, str): + self.model = KobertBiEncoder() + self.model.load(model) + else: + self.model = model + + # ๋ชจ๋ธ์„ ํ•ด๋‹น ๋””๋ฐ”์ด์Šค๋กœ ์ด๋™ + self.model = self.model.to(device) + self.model.eval() + + # ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ + self.valid_dataset = valid_dataset + + # ์ธ๋ฑ์Šค๊ฐ€ ๊ฒฝ๋กœ๋กœ ์ฃผ์–ด์ง„ ๊ฒฝ์šฐ ๋กœ๋“œ + if isinstance(index, str): + self.index = DenseFlatIndexer() + self.index.deserialize(path=index) + else: + self.index = index + self.model = model.to(device) + self.device = device + self.tokenizer = valid_dataset.tokenizer + self.val_batch_size = val_batch_size + self.valid_loader = torch.utils.data.DataLoader( + dataset=valid_dataset.dataset, + batch_sampler=KorQuadSampler(valid_dataset.dataset, batch_size=val_batch_size, drop_last=False), + collate_fn=lambda x: korquad_collator(x, padding_value=valid_dataset.pad_token_id), + num_workers=4, + ) + self.index = index + + def val_top_k_acc(self, k: List[int] = [5] + list(range(10, 101, 10))): + """validation set์—์„œ top k ์ •ํ™•๋„๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.""" + + self.model.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ + k_max = max(k) + sample_cnt = 0 + retr_cnt = defaultdict(int) + with torch.no_grad(): + for batch in tqdm(self.valid_loader, 
desc="valid"): + # batch_q, batch_q_attn_mask, batch_p_id, batch_p, batch_p_attn_mask + q, q_mask, p_id, a, a_mask = batch + q, q_mask = ( + q.to(self.device), + q_mask.to(self.device), + ) + q_emb = self.model(q, q_mask, "query") # bsz x bert_dim + result = self.index.search_knn(query_vectors=q_emb.cpu().numpy(), top_docs=k_max) + + for (pred_idx_lst, _), true_idx, _a, _a_mask in zip(result, p_id, a, a_mask): + a_len = _a_mask.sum() + _a = _a[:a_len] + _a = _a[1:-1] + _a_txt = self.tokenizer.decode(_a).strip() + docs = [pickle.load(open(get_passage_file([idx]), "rb"))[idx] for idx in pred_idx_lst] + + for _k in k: + if _a_txt in " ".join(docs[:_k]): + retr_cnt[_k] += 1 + + bsz = q.size(0) + sample_cnt += bsz + retr_acc = {_k: float(v) / float(sample_cnt) for _k, v in retr_cnt.items()} + return retr_acc + + def retrieve(self, query: str, k: int = 100): + """์ฃผ์–ด์ง„ ์ฟผ๋ฆฌ์— ๋Œ€ํ•ด ๊ฐ€์žฅ ์œ ์‚ฌ๋„๊ฐ€ ๋†’์€ passage๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + self.model.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ + tok = self.tokenizer.batch_encode_plus([query], truncation=True, padding=True, max_length=512) + + # Ensure tensors are moved to the same device as the model (cuda:0) + input_ids = T(tok["input_ids"]).to(self.device) + attention_mask = T(tok["attention_mask"]).to(self.device) + + with torch.no_grad(): + out = self.model(input_ids, attention_mask, "query") + + result = self.index.search_knn(query_vectors=out.cpu().numpy(), top_docs=k) + # logger.debug(result) + # ์›๋ฌธ ๊ฐ€์ ธ์˜ค๊ธฐ + passages = [] + for idx, sim in zip(*result[0]): + # logger.debug(idx) + path = get_passage_file([idx]) + if not path: + logger.debug(f"์˜ฌ๋ฐ”๋ฅธ ๊ฒฝ๋กœ์— ํ”ผํดํ™”๋œ ์œ„ํ‚คํ”ผ๋””์•„๊ฐ€ ์žˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”.No single passage path for {idx}") + continue + with open(path, "rb") as f: + passage_dict = pickle.load(f) + # logger.debug(f"passage : {passage_dict[idx]}, sim : {sim}") + passages.append((passage_dict[idx], sim)) + # logger.debug("์„ฑ๊ณต!!!!!!") + return passages + + +if __name__ == "__main__": + # parser = argparse.ArgumentParser() + # parser.add_argument("--query", "-q", type=str, required=True) + # parser.add_argument("--k", "-k", type=int, required=True) + # args = parser.parse_args() + + model = KobertBiEncoder() + model.load("./output/my_model.pt") + model.eval() + valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json") + index = DenseFlatIndexer() + index.deserialize(path="./2050iter_flat") + + retriever = KorDPRRetriever(model=model, valid_dataset=valid_dataset, index=index) + + # 'query'์™€ 'k' ๊ฐ’์„ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค. + query = "(๊ฐ€)์ด/๊ฐ€ ํฌ๊ฒŒ ๋…ธํ•˜์—ฌ ํ•œ์„ฑ๋ถ€์˜ ์กฐ๋ก€(็š‚้šท)์™€ ๋ณ‘์กธ๋กœ ํ•˜์—ฌ ๊ธˆ ํ•œ ๊ฐ• ๋ฐ–์œผ๋กœ ๋ชฐ์•„๋‚ด๊ฒŒ ํ•˜๊ณ  ๋“œ๋””์–ด ์ฒœ์—ฌ ๊ณณ์˜ ์„œ์›์„ ์ฒ ํํ•˜๊ณ  ๊ทธ ํ† ์ง€๋ฅผ ๋ชฐ์ˆ˜ํ•˜์—ฌ ๊ด€์— ์†ํ•˜๊ฒŒ ํ•˜์˜€๋‹ค.๏ผ๋Œ€ํ•œ๊ณ„๋…„์‚ฌ ๏ผ" # noqa: E501 + # query = "์„ฑํ•™์ง‘์š”์˜ ์ €์ž๋Š”?" + k = 10 # ์ƒ์œ„ 20๊ฐœ ์œ ์‚ฌํ•œ passage๋ฅผ ์ถœ๋ ฅํ•˜๋ ค๋ฉด k๋ฅผ 20์œผ๋กœ ์„ค์ • + + # retrieve ๋ฉ”์„œ๋“œ๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ๊ฐ€์žฅ ์œ ์‚ฌ๋„๊ฐ€ ๋†’์€ k๊ฐœ์˜ passage๋ฅผ ์ฐพ์Šต๋‹ˆ๋‹ค. + passages = retriever.retrieve(query=query, k=k) + + # ์ถœ๋ ฅ: ์œ ์‚ฌ๋„ ๋†’์€ passage์™€ ๊ทธ ์œ ์‚ฌ๋„๋ฅผ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค. 
+ for idx, (passage, sim) in enumerate(passages): + logger.debug(f"Rank {idx + 1} | Similarity: {sim:.4f} | Passage: {passage}") diff --git a/code/rag/retriever_bm25.py b/code/rag/retriever_bm25.py new file mode 100644 index 0000000..532fcb9 --- /dev/null +++ b/code/rag/retriever_bm25.py @@ -0,0 +1,140 @@ +import json +import os +import pickle +from typing import Dict, List, Optional + +from loguru import logger +import numpy as np +from rank_bm25 import BM25Okapi + + +# from konlpy.tag import Okt +# okt = Okt() +# def okt_specific_pos_tokenizer(text, stem=True, norm=True): +# # pos ํƒœ๊น… ์ˆ˜ํ–‰ +# pos_tagged = okt.pos(text, stem=stem, norm=norm) +# # ๋ช…์‚ฌ(Noun), ํ˜•์šฉ์‚ฌ(Adjective), ๋™์‚ฌ(Verb)๋งŒ ํ•„ํ„ฐ๋ง +# filtered_words = [word for word, pos in pos_tagged if pos in ["Noun", "Adjective", "Verb"]] +# return filtered_words + + +# Deprecated: ๋„ˆ๋ฌด ๋А๋ ค์„œ ๋” ์ด์ƒ ์‚ฌ์šฉํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. +class BM25Retriever: + def __init__( + self, + tokenize_fn=None, + data_path: Optional[str] = "../data/", + pickle_filename: str = "wiki_bm25.pkl", + doc_filename: Optional[str] = "wiki_document.json", + ) -> None: + self.tokenize_fn = tokenize_fn if tokenize_fn else lambda x: x.split() + self.pickle_path = os.path.join(data_path, pickle_filename) + self.bm25 = None + self.corpus = [] + + # ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ + self._load_dataset(os.path.join(data_path, doc_filename)) + + # ๊ธฐ์กด ์ธ๋ฑ์Šค ๋กœ๋“œ + if os.path.exists(self.pickle_path): + self._load_pickle() + return + + # ์ธ๋ฑ์Šค ์ƒ์„ฑ + self._initialize_retriever() + + def _load_dataset(self, json_path): + logger.info("๋ฌธ์„œ ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ") + with open(json_path, "r", encoding="utf-8") as f: + docs = json.load(f) + self.corpus = [f"{doc['title']}: {doc['text']}" for doc in docs] + + def _load_pickle(self): + logger.info("๊ธฐ์กด BM25 ์ธ๋ฑ์Šค ๋กœ๋“œ") + with open(self.pickle_path, "rb") as f: + data = pickle.load(f) + self.bm25 = data["bm25"] + + def _initialize_retriever(self): + logger.info("์ƒˆ๋กœ์šด BM25 ์ธ๋ฑ์Šค ์ƒ์„ฑ") + + tokenized_corpus = [self.tokenize_fn(doc) for doc in self.corpus] + self.bm25 = BM25Okapi(tokenized_corpus) + + with open(self.pickle_path, "wb") as f: + pickle.dump( + { + "bm25": self.bm25, + }, + f, + ) + logger.info("์ธ๋ฑ์Šค ์ƒ์„ฑ ์™„๋ฃŒ") + + def retrieve(self, query: str, top_k: int = 3) -> List[Dict]: + """ + ์ฃผ์–ด์ง„ ์ฟผ๋ฆฌ์— ๋Œ€ํ•ด ์ƒ์œ„ k๊ฐœ์˜ ๋ฌธ์„œ๋ฅผ ๊ฒ€์ƒ‰ํ•ฉ๋‹ˆ๋‹ค. + """ + if not self.bm25: + raise Exception("BM25 ๋ชจ๋ธ์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.") + + tokenized_query = self.tokenize_fn(query) + doc_scores = self.bm25.get_scores(tokenized_query) + top_indices = np.argsort(doc_scores)[-top_k:][::-1] + + results = [] + for idx in top_indices: + results.append( + { + "text": self.corpus[idx], + "score": float(doc_scores[idx]), + } + ) + return results + + def bulk_retrieve(self, queries: List[str], top_k: int = 3) -> List[List[Dict]]: + """ + ์—ฌ๋Ÿฌ ์ฟผ๋ฆฌ์— ๋Œ€ํ•ด ์ผ๊ด„์ ์œผ๋กœ ๊ฒ€์ƒ‰์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. 
+ """ + if not self.bm25: + raise Exception("BM25 ๋ชจ๋ธ์ด ์ดˆ๊ธฐํ™”๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.") + + results = [] + logger.info(f"{len(queries)}๊ฐœ ์ฟผ๋ฆฌ ์ผ๊ด„ ๊ฒ€์ƒ‰") + # ๋ชจ๋“  ์ฟผ๋ฆฌ๋ฅผ ํ•œ ๋ฒˆ์— ํ† ํฌ๋‚˜์ด์ง• + tokenized_queries = [self.tokenize_fn(query) for query in queries] + + # ๊ฐ ์ฟผ๋ฆฌ๋ณ„๋กœ ๊ฒ€์ƒ‰ ์ˆ˜ํ–‰ + for tokenized_query in tokenized_queries: + doc_scores = self.bm25.get_scores(tokenized_query) + top_indices = np.argsort(doc_scores)[-top_k:][::-1] + + query_results = [] + for idx in top_indices: + query_results.append( + { + "text": self.corpus[idx], + "score": float(doc_scores[idx]), + } + ) + results.append(query_results) + + logger.info(f"{len(queries)}๊ฐœ ์ฟผ๋ฆฌ ์ผ๊ด„ ๊ฒ€์ƒ‰ ์™„๋ฃŒ") + return results + + +if __name__ == "__main__": + os.chdir("..") + retriever = BM25Retriever( + tokenize_fn=None, + data_path="../data/", + pickle_filename="wiki_bm25.pkl", + doc_filename="wiki.json", + ) + + query = "์„ ๋น„๋“ค ์ˆ˜๋งŒ ๋ช…์ด ๋Œ€๊ถ ์•ž์— ๋ชจ์—ฌ ๋งŒ ๋™๋ฌ˜์™€ ์„œ์›์„ ๋‹ค์‹œ ์„ค๋ฆฝํ•  ๊ฒƒ์„ ์ฒญํ•˜๋‹ˆ, (๊ฐ€)์ด/๊ฐ€ ํฌ๊ฒŒ ๋…ธํ•˜์—ฌ ํ•œ์„ฑ๋ถ€์˜ ์กฐ๋ก€(็š‚้šท)์™€ ๋ณ‘์กธ๋กœ ํ•˜์—ฌ ๊ธˆ ํ•œ ๊ฐ• ๋ฐ–์œผ๋กœ ๋ชฐ์•„๋‚ด๊ฒŒ ํ•˜๊ณ  ๋“œ๋””์–ด ์ฒœ์—ฌ ๊ณณ์˜ ์„œ์›์„ ์ฒ ํํ•˜๊ณ  ๊ทธ ํ† ์ง€๋ฅผ ๋ชฐ์ˆ˜ํ•˜์—ฌ ๊ด€์— ์†ํ•˜๊ฒŒ ํ•˜์˜€๋‹ค.๏ผ๋Œ€ํ•œ๊ณ„๋…„์‚ฌ" # noqa: E501 + results = retriever.retrieve(query, top_k=5) + + for i, result in enumerate(results, 1): + logger.debug(f"\n๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ {i}") + logger.debug(f"์ ์ˆ˜: {result['score']:.4f}") + logger.debug(f"๋‚ด์šฉ: {result['text'][:200]}...") diff --git a/code/rag/retriever_elastic.py b/code/rag/retriever_elastic.py new file mode 100644 index 0000000..597a4d2 --- /dev/null +++ b/code/rag/retriever_elastic.py @@ -0,0 +1,190 @@ +import json +import os +import re +from typing import Dict, List, Optional +import warnings + +from dotenv import load_dotenv +from elasticsearch import Elasticsearch, ElasticsearchWarning +from loguru import logger + + +# ElasticsearchWarning ๋ฌด์‹œ +warnings.filterwarnings("ignore", category=ElasticsearchWarning) + + +class ElasticsearchRetriever: + def __init__( + self, + index_name: str = "wiki-index", + data_path: Optional[str] = None, + setting_path: Optional[str] = None, + doc_filename: Optional[str] = None, + ) -> None: + self.index_name = index_name + self.client = self._connect_elasticsearch( + os.getenv("ELASTICSEARCH_URL"), os.getenv("ELASTICSEARCH_ID"), os.getenv("ELASTICSEARCH_PW") + ) + + # ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ๋ฐ ์ธ๋ฑ์Šค ์ดˆ๊ธฐํ™” + if not self.client.indices.exists(index=self.index_name): + if data_path and setting_path and doc_filename: + docs = self._load_dataset(os.path.join(data_path, doc_filename)) + self._initialize_index(setting_path) + self._insert_documents(docs) + else: + raise ValueError(f"์กด์žฌํ•˜์ง€ ์•Š๋Š” ์ธ๋ฑ์Šค: {index_name}") + + def _connect_elasticsearch(self, url: str, id: str, pw: str) -> Elasticsearch: + """ElasticSearch ํด๋ผ์ด์–ธํŠธ ์—ฐ๊ฒฐ""" + es = Elasticsearch( + url, + basic_auth=(id, pw), + request_timeout=30, + max_retries=10, + retry_on_timeout=True, + verify_certs=False, + ) + logger.info(f"Elasticsearch ์—ฐ๊ฒฐ ์ƒํƒœ: {es.ping()}") + return es + + def _load_dataset(self, doc_filename) -> Dict: + """๋ฌธ์„œ ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ""" + with open(doc_filename, "r", encoding="utf-8") as f: + return json.load(f) + + def _initialize_index(self, setting_path) -> None: + """์ธ๋ฑ์Šค ์ƒ์„ฑ ๋ฐ ์„ค์ •""" + with open(setting_path, "r") as f: + setting = json.load(f) + self.client.indices.create(index=self.index_name, body=setting) + logger.info("์ธ๋ฑ์Šค ์ƒ์„ฑ 
์™„๋ฃŒ") + + def _delete_index(self): + if not self.client.indices.exists(index=self.index_name): + logger.info("Index doesn't exist.") + return + + self.client.indices.delete(index=self.index_name) + logger.info("Index deletion has been completed") + + def _insert_documents(self, docs) -> None: + """๋ฌธ์„œ ๋ฐ์ดํ„ฐ bulk ์‚ฝ์ž…""" + + def _preprocess(text): + text = re.sub(r"\n", " ", text) + text = re.sub(r"\\n", " ", text) + text = re.sub(r"#", " ", text) + text = re.sub(r"\s+", " ", text).strip() # ๋‘ ๊ฐœ ์ด์ƒ์˜ ์—ฐ์†๋œ ๊ณต๋ฐฑ์„ ํ•˜๋‚˜๋กœ ์น˜ํ™˜ + return text + + bulk_data = [] + for i, doc in enumerate(docs): + # bulk ์ž‘์—…์„ ์œ„ํ•œ ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ + bulk_data.append({"index": {"_index": self.index_name, "_id": i}}) + # ์‹ค์ œ ๋ฌธ์„œ ๋ฐ์ดํ„ฐ + bulk_data.append({"title": doc["title"], "text": _preprocess(doc["text"])}) + + # 1000๊ฐœ ๋‹จ์œ„๋กœ ๋ฒŒํฌ ์‚ฝ์ž… ์ˆ˜ํ–‰ + if (i + 1) % 1000 == 0: + try: + response = self.client.bulk(body=bulk_data) + if response["errors"]: + logger.warning(f"{i+1}๋ฒˆ์งธ ๋ฒŒํฌ ์‚ฝ์ž… ์ค‘ ์ผ๋ถ€ ์˜ค๋ฅ˜ ๋ฐœ์ƒ") + bulk_data = [] # ๋ฒŒํฌ ๋ฐ์ดํ„ฐ ์ดˆ๊ธฐํ™” + logger.info(f"{i+1}๊ฐœ ๋ฌธ์„œ ๋ฒŒํฌ ์‚ฝ์ž… ์™„๋ฃŒ") + except Exception as e: + logger.error(f"๋ฒŒํฌ ์‚ฝ์ž… ์‹คํŒจ (์ธ๋ฑ์Šค: {i}): {e}") + bulk_data = [] # ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ์—๋„ ๋ฐ์ดํ„ฐ ์ดˆ๊ธฐํ™” + + # ๋‚จ์€ ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ + if bulk_data: + try: + response = self.client.bulk(body=bulk_data) + if response["errors"]: + logger.warning("๋งˆ์ง€๋ง‰ ๋ฒŒํฌ ์‚ฝ์ž… ์ค‘ ์ผ๋ถ€ ์˜ค๋ฅ˜ ๋ฐœ์ƒ") + except Exception as e: + logger.error(f"๋งˆ์ง€๋ง‰ ๋ฒŒํฌ ์‚ฝ์ž… ์‹คํŒจ: {e}") + + # ์ตœ์ข… ๋ฌธ์„œ ์ˆ˜ ํ™•์ธ + n_records = self.client.count(index=self.index_name)["count"] + logger.info(f"์ด {n_records}๊ฐœ ๋ฌธ์„œ ์‚ฝ์ž… ์™„๋ฃŒ") + + def retrieve(self, query: str, top_k: int = 3) -> List[Dict]: + """๋‹จ์ผ ์ฟผ๋ฆฌ์— ๋Œ€ํ•œ ๊ฒ€์ƒ‰ ์ˆ˜ํ–‰""" + query_body = {"query": {"bool": {"must": [{"match": {"text": query}}]}}} + + response = self.client.search(index=self.index_name, body=query_body, size=top_k) + + results = [] + for hit in response["hits"]["hits"]: + results.append({"text": f"{hit['_source']['title']}: {hit['_source']['text']}", "score": hit["_score"]}) + return results + + def bulk_retrieve(self, queries: List[str], top_k: int = 3) -> List[List[Dict]]: + """์—ฌ๋Ÿฌ ์ฟผ๋ฆฌ์— ๋Œ€ํ•œ ์ผ๊ด„ ๊ฒ€์ƒ‰ ์ˆ˜ํ–‰ (msearch API ์‚ฌ์šฉ)""" + logger.info(f"{len(queries)}๊ฐœ ์ฟผ๋ฆฌ ์ผ๊ด„ ๊ฒ€์ƒ‰") + + # msearch API๋ฅผ ์œ„ํ•œ bulk ์ฟผ๋ฆฌ ์ค€๋น„ + bulk_query = [] + for query in queries: + # ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ๋ผ์ธ + bulk_query.append({"index": self.index_name}) + # ์ฟผ๋ฆฌ ๋ผ์ธ + bulk_query.append({"query": {"bool": {"must": [{"match": {"text": query}}]}}, "size": top_k}) + + try: + # msearch API ํ˜ธ์ถœ + response = self.client.msearch(body=bulk_query) + + # ๊ฒฐ๊ณผ ์ฒ˜๋ฆฌ + results = [] + for response_item in response["responses"]: + query_results = [] + if not response_item.get("error"): + for hit in response_item["hits"]["hits"]: + query_results.append( + {"text": f"{hit['_source']['title']}: {hit['_source']['text']}", "score": hit["_score"]} + ) + results.append(query_results) + + logger.info(f"{len(queries)}๊ฐœ ์ฟผ๋ฆฌ ์ผ๊ด„ ๊ฒ€์ƒ‰ ์™„๋ฃŒ") + return results + + except Exception as e: + logger.error(f"Bulk search ์‹คํŒจ: {e}") + return [[] for _ in queries] # ์—๋Ÿฌ ๋ฐœ์ƒ ์‹œ ๋นˆ ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜ + + +if __name__ == "__main__": + config_folder = os.path.join(os.path.dirname(__file__), "..", "..", "config") + load_dotenv(os.path.join(config_folder, ".env")) + + retriever = ElasticsearchRetriever( + 
data_path="../data/", + index_name="wiki-index", + setting_path="../config/elastic_setting.json", + doc_filename="wiki.json", + ) + + # ์ƒˆ๋กœ์šด ๋ฌธ์„œ ์ถ”๊ฐ€ ์‚ฝ์ž…์‹œ์—๋งŒ ์‚ฌ์šฉ + if False: + current_count = retriever.client.count(index=retriever.index_name)["count"] + + logger.info("์ƒˆ๋กœ์šด ๋ฌธ์„œ ์ถ”๊ฐ€ ์‹œ์ž‘") + with open("new_wiki.json", "r", encoding="utf-8") as f: + new_docs = json.load(f) + retriever._insert_documents(new_docs) + + # ๋ฌธ์„œ ์ถ”๊ฐ€ ํ™•์ธ + new_count = retriever.client.count(index=retriever.index_name)["count"] + logger.info(f"๋ฌธ์„œ ์ถ”๊ฐ€ ์™„๋ฃŒ: {current_count} -> {new_count} ({new_count-current_count}๊ฐœ ์ถ”๊ฐ€)") + + # ๋ฌธ์„œ ๊ฒ€์ƒ‰ ํ…Œ์ŠคํŠธ + query = "์„ ๋น„๋“ค ์ˆ˜๋งŒ ๋ช…์ด ๋Œ€๊ถ ์•ž์— ๋ชจ์—ฌ ๋งŒ ๋™๋ฌ˜์™€ ์„œ์›์„ ๋‹ค์‹œ ์„ค๋ฆฝํ•  ๊ฒƒ์„ ์ฒญํ•˜๋‹ˆ, (๊ฐ€)์ด/๊ฐ€ ํฌ๊ฒŒ ๋…ธํ•˜์—ฌ ํ•œ์„ฑ๋ถ€์˜ ์กฐ๋ก€(็š‚้šท)์™€ ๋ณ‘์กธ๋กœ ํ•˜์—ฌ ๊ธˆ ํ•œ ๊ฐ• ๋ฐ–์œผ๋กœ ๋ชฐ์•„๋‚ด๊ฒŒ ํ•˜๊ณ  ๋“œ๋””์–ด ์ฒœ์—ฌ ๊ณณ์˜ ์„œ์›์„ ์ฒ ํํ•˜๊ณ  ๊ทธ ํ† ์ง€๋ฅผ ๋ชฐ์ˆ˜ํ•˜์—ฌ ๊ด€์— ์†ํ•˜๊ฒŒ ํ•˜์˜€๋‹ค.๏ผ๋Œ€ํ•œ๊ณ„๋…„์‚ฌ" # noqa: E501 + results = retriever.retrieve(query, top_k=5) + + for i, result in enumerate(results, 1): + logger.debug(f"\n๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ {i}") + logger.debug(f"์ ์ˆ˜: {result['score']:.4f}") + logger.debug(f"๋‚ด์šฉ: {result['text'][:200]}...") diff --git a/code/rag/train.py b/code/rag/train.py new file mode 100644 index 0000000..20a33fe --- /dev/null +++ b/code/rag/train.py @@ -0,0 +1,228 @@ +from copy import deepcopy +import logging +import os +from typing import Tuple + +from dpr_data import KorQuadSampler, korquad_collator +import numpy as np +import torch +from tqdm import tqdm +import transformers +import wandb + + +# Ensure output directory exists +os.makedirs("./output", exist_ok=True) + +# Set up logging +os.makedirs("logs", exist_ok=True) +logging.basicConfig( + filename="logs/log.log", + level=logging.DEBUG, + format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s", +) +logger = logging.getLogger() # get root logger + + +class Trainer: + """Basic trainer""" + + def __init__( + self, + model, + device, + train_dataset, + valid_dataset, + num_epoch: int, + batch_size: int, + lr: float, + betas: Tuple[float], + num_warmup_steps: int, + num_training_steps: int, + valid_every: int, + best_val_ckpt_path: str, + ): + self.model = model.to(device) + self.device = device + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=betas) + self.scheduler = transformers.get_linear_schedule_with_warmup( + self.optimizer, num_warmup_steps, num_training_steps + ) + self.train_loader = torch.utils.data.DataLoader( + dataset=train_dataset.dataset, + batch_sampler=KorQuadSampler(train_dataset.dataset, batch_size=batch_size, drop_last=False), + collate_fn=lambda x: korquad_collator(x, padding_value=train_dataset.pad_token_id), + num_workers=4, + ) + self.valid_loader = torch.utils.data.DataLoader( + dataset=valid_dataset.dataset, + batch_sampler=KorQuadSampler(valid_dataset.dataset, batch_size=batch_size, drop_last=False), + collate_fn=lambda x: korquad_collator(x, padding_value=valid_dataset.pad_token_id), + num_workers=4, + ) + + self.batch_size = batch_size + self.num_epoch = num_epoch + self.valid_every = valid_every + self.lr = lr + self.betas = betas + self.num_warmup_steps = num_warmup_steps + self.num_training_steps = num_training_steps + self.best_val_ckpt_path = best_val_ckpt_path + self.best_val_optim_path = best_val_ckpt_path.split(".pt")[0] + "_optim.pt" + + self.start_ep = 1 + self.start_step = 1 + + def ibn_loss(self, pred: 
torch.FloatTensor): + """In-batch negative loss calculation.""" + bsz = pred.size(0) + target = torch.arange(bsz).to(self.device) + return torch.nn.functional.cross_entropy(pred, target) + + def batch_acc(self, pred: torch.FloatTensor): + """Batch accuracy calculation.""" + bsz = pred.size(0) + target = torch.arange(bsz) + return (pred.detach().cpu().max(1).indices == target).sum().float() / bsz + + def fit(self): + """Train the model.""" + wandb.init( + project="kordpr", + entity="lucas01", + config={ + "batch_size": self.batch_size, + "lr": self.lr, + "betas": self.betas, + "num_warmup_steps": self.num_warmup_steps, + "num_training_steps": self.num_training_steps, + "valid_every": self.valid_every, + }, + ) + logger.debug("start training") + self.model.train() # Set model to training mode + global_step_cnt = 0 + prev_best = None + for ep in range(self.start_ep, self.num_epoch + 1): + for step, batch in enumerate(tqdm(self.train_loader, desc=f"epoch {ep} batch"), 1): + if ep == self.start_ep and step < self.start_step: + continue # Skip until the saved checkpoint + + self.model.train() # Set model to training mode + global_step_cnt += 1 + q, q_mask, _, p, p_mask = batch + q, q_mask, p, p_mask = ( + q.to(self.device), + q_mask.to(self.device), + p.to(self.device), + p_mask.to(self.device), + ) + q_emb = self.model(q, q_mask, "query") + p_emb = self.model(p, p_mask, "passage") + pred = torch.matmul(q_emb, p_emb.T) + loss = self.ibn_loss(pred) + acc = self.batch_acc(pred) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + self.scheduler.step() + log = { + "epoch": ep, + "step": step, + "global_step": global_step_cnt, + "train_step_loss": loss.cpu().item(), + "current_lr": float(self.scheduler.get_last_lr()[0]), + "step_acc": acc, + } + if global_step_cnt % self.valid_every == 0: + eval_dict = self.evaluate() + log.update(eval_dict) + if prev_best is None or eval_dict["valid_loss"] < prev_best: # Save best validation model + self.save_training_state(log) + wandb.log(log) + + def evaluate(self): + """Evaluate the model.""" + self.model.eval() # Set model to evaluation mode + loss_list = [] + sample_cnt = 0 + valid_acc = 0 + with torch.no_grad(): + for batch in self.valid_loader: + q, q_mask, _, p, p_mask = batch + q, q_mask, p, p_mask = ( + q.to(self.device), + q_mask.to(self.device), + p.to(self.device), + p_mask.to(self.device), + ) + q_emb = self.model(q, q_mask, "query") + p_emb = self.model(p, p_mask, "passage") + pred = torch.matmul(q_emb, p_emb.T) + loss = self.ibn_loss(pred) + step_acc = self.batch_acc(pred) + + bsz = q.size(0) + sample_cnt += bsz + valid_acc += step_acc * bsz + loss_list.append(loss.cpu().item() * bsz) + return { + "valid_loss": np.array(loss_list).sum() / float(sample_cnt), + "valid_acc": valid_acc / float(sample_cnt), + } + + def save_training_state(self, log_dict: dict) -> None: + """Save model, optimizer, and other training states.""" + checkpoint_path = os.path.join("./output", self.best_val_ckpt_path) + self.model.checkpoint(checkpoint_path) + training_state = { + "optimizer_state": deepcopy(self.optimizer.state_dict()), + "scheduler_state": deepcopy(self.scheduler.state_dict()), + } + training_state.update(log_dict) + optim_path = os.path.join("./output", self.best_val_optim_path) + torch.save(training_state, optim_path) + logger.debug(f"Saved optimizer/scheduler state into {optim_path}") + + def load_training_state(self) -> None: + """Load model, optimizer, and other training states.""" + checkpoint_path = os.path.join("./output", 
self.best_val_ckpt_path) + if os.path.exists(checkpoint_path): + self.model.load(checkpoint_path) + optim_path = os.path.join("./output", self.best_val_optim_path) + training_state = torch.load(optim_path) + logger.debug(f"Loaded optimizer/scheduler state from {optim_path}") + self.optimizer.load_state_dict(training_state["optimizer_state"]) + self.scheduler.load_state_dict(training_state["scheduler_state"]) + self.start_ep = training_state["epoch"] + self.start_step = training_state["step"] + logger.debug(f"Resumed training from epoch {self.start_ep} / step {self.start_step}") + else: + logger.debug("No checkpoint found, starting training from scratch.") + + +# if __name__ == "__main__": +# device = torch.device("cuda:0") +# model = KobertBiEncoder() +# train_dataset = KorQuadDataset("./data/KorQuAD_v1.0_train.json") +# valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json") +# my_trainer = Trainer( +# model=model, +# device=device, +# train_dataset=train_dataset, +# valid_dataset=valid_dataset, +# num_epoch=1, +# batch_size=128 - 32, +# lr=1e-5, +# betas=(0.9, 0.99), +# num_warmup_steps=1000, +# num_training_steps=100000, +# valid_every=30, +# best_val_ckpt_path="my_model.pt", +# ) +# my_trainer.load_training_state() +# my_trainer.fit() # Start training +# # eval_dict = my_trainer.evaluate() # If you want to evaluate after training +# # print(eval_dict) diff --git a/code/rag/trainer.py b/code/rag/trainer.py new file mode 100644 index 0000000..79502eb --- /dev/null +++ b/code/rag/trainer.py @@ -0,0 +1,257 @@ +import os +import sys + + +# ํ˜„์žฌ ์ฝ”๋“œ๊ฐ€ ์žˆ๋Š” ๋””๋ ‰ํ† ๋ฆฌ ๊ธฐ์ค€์œผ๋กœ ์ƒ์œ„ ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ `sys.path`์— ์ถ”๊ฐ€ +sys.path.append(os.path.abspath(os.path.dirname(__file__))) +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from copy import deepcopy +import logging +from typing import Tuple + +from dpr_data import KorQuadDataset, KorQuadSampler, korquad_collator +from encoder import KobertBiEncoder +import numpy as np +import torch +from tqdm import tqdm +import transformers + + +os.makedirs("logs", exist_ok=True) +logging.basicConfig( + filename="logs/log.log", + level=logging.DEBUG, + format="[%(asctime)s | %(funcName)s @ %(pathname)s] %(message)s", +) +logger = logging.getLogger() # get root logger + + +class Trainer: + """basic trainer""" + + def __init__( + self, + model, + device, + train_dataset, + valid_dataset, + num_epoch: int, + batch_size: int, + lr: float, + betas: Tuple[float], + num_warmup_steps: int, + num_training_steps: int, + valid_every: int, + best_val_ckpt_path: str, + ): + self.model = model.to(device) + self.device = device + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=betas) + self.scheduler = transformers.get_linear_schedule_with_warmup( + self.optimizer, num_warmup_steps, num_training_steps + ) + self.train_loader = torch.utils.data.DataLoader( + dataset=train_dataset.dataset, + batch_sampler=KorQuadSampler(train_dataset.dataset, batch_size=batch_size, drop_last=False), + collate_fn=lambda x: korquad_collator(x, padding_value=train_dataset.pad_token_id), + num_workers=4, + ) + self.valid_loader = torch.utils.data.DataLoader( + dataset=valid_dataset.dataset, + batch_sampler=KorQuadSampler(valid_dataset.dataset, batch_size=batch_size, drop_last=False), + collate_fn=lambda x: korquad_collator(x, padding_value=valid_dataset.pad_token_id), + num_workers=4, + ) + + self.batch_size = batch_size + self.num_epoch = num_epoch + self.valid_every = valid_every + self.lr = lr + self.betas = 
betas + self.num_warmup_steps = num_warmup_steps + self.num_training_steps = num_training_steps + self.best_val_ckpt_path = best_val_ckpt_path + self.best_val_optim_path = best_val_ckpt_path.split(".pt")[0] + "_optim.pt" + + self.start_ep = 1 + self.start_step = 1 + + def ibn_loss(self, pred: torch.FloatTensor): + """in-batch negative๋ฅผ ํ™œ์šฉํ•œ batch์˜ loss๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค. + pred : bsz x bsz ๋˜๋Š” bsz x bsz*2์˜ logit ๊ฐ’์„ ๊ฐ€์ง. ํ›„์ž๋Š” hard negative๋ฅผ ํฌํ•จํ•˜๋Š” ๊ฒฝ์šฐ. + """ + bsz = pred.size(0) + target = torch.arange(bsz).to(self.device) # ์ฃผ๋Œ€๊ฐ์„ ์ด answer + return torch.nn.functional.cross_entropy(pred, target) + + def batch_acc(self, pred: torch.FloatTensor): + """batch ๋‚ด์˜ accuracy๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.""" + bsz = pred.size(0) + target = torch.arange(bsz) # ์ฃผ๋Œ€๊ฐ์„ ์ด answer + return (pred.detach().cpu().max(1).indices == target).sum().float() / bsz + + def fit(self): + """๋ชจ๋ธ์„ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.""" + # wandb.init( + # project="personal", + # entity="gayean01", + # config={ + # "batch_size": self.batch_size, + # "lr": self.lr, + # "betas": self.betas, + # "num_warmup_steps": self.num_warmup_steps, + # "num_training_steps": self.num_training_steps, + # "valid_every": self.valid_every, + # }, + # ) + logger.debug("start training") + self.model.train() # ํ•™์Šต๋ชจ๋“œ + global_step_cnt = 0 + prev_best = None + for ep in range(self.start_ep, self.num_epoch + 1): + for step, batch in enumerate(tqdm(self.train_loader, desc=f"epoch {ep} batch"), 1): + if ep == self.start_ep and step < self.start_step: + continue # ์ค‘๊ฐ„๋ถ€ํ„ฐ ํ•™์Šต์‹œํ‚ค๋Š” ๊ฒฝ์šฐ ํ•ด๋‹น ์ง€์ ๊นŒ์ง€ ๋ณต์› + + self.model.train() # ํ•™์Šต ๋ชจ๋“œ + global_step_cnt += 1 + q, q_mask, _, p, p_mask = batch + q, q_mask, p, p_mask = ( + q.to(self.device), + q_mask.to(self.device), + p.to(self.device), + p_mask.to(self.device), + ) + q_emb = self.model(q, q_mask, "query") # bsz x bert_dim + p_emb = self.model(p, p_mask, "passage") # bsz x bert_dim + pred = torch.matmul(q_emb, p_emb.T) # bsz x bsz + loss = self.ibn_loss(pred) + acc = self.batch_acc(pred) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + self.scheduler.step() + log = { + "epoch": ep, + "step": step, + "global_step": global_step_cnt, + "train_step_loss": loss.cpu().item(), + "current_lr": float(self.scheduler.get_last_lr()[0]), # parameter group 1๊ฐœ์ด๋ฏ€๋กœ + "step_acc": acc, + } + if global_step_cnt % self.valid_every == 0: + eval_dict = self.evaluate() + log.update(eval_dict) + if prev_best is None or eval_dict["valid_loss"] < prev_best: # best val loss์ธ ๊ฒฝ์šฐ ์ €์žฅ + # self.model.checkpoint(self.best_val_ckpt_path) + self.save_training_state(log) + # wandb.log(log) + + def evaluate(self): + """๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.""" + self.model.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ + loss_list = [] + sample_cnt = 0 + valid_acc = 0 + with torch.no_grad(): + for batch in self.valid_loader: + q, q_mask, _, p, p_mask = batch + q, q_mask, p, p_mask = ( + q.to(self.device), + q_mask.to(self.device), + p.to(self.device), + p_mask.to(self.device), + ) + q_emb = self.model(q, q_mask, "query") # bsz x bert_dim + p_emb = self.model(p, p_mask, "passage") # bsz x bert_dim + pred = torch.matmul(q_emb, p_emb.T) # bsz x bsz + loss = self.ibn_loss(pred) + step_acc = self.batch_acc(pred) + + bsz = q.size(0) + sample_cnt += bsz + valid_acc += step_acc * bsz + loss_list.append(loss.cpu().item() * bsz) + valid_loss = np.array(loss_list).sum() / float(sample_cnt) + valid_acc = valid_acc / float(sample_cnt) + + # ์ฝ˜์†”์— ์ถœ๋ ฅ + 
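# Per-sample weighted mean: each batch contributed loss * bsz and acc * bsz above,
+ # so dividing by sample_cnt stays correct even when the final batch is smaller (drop_last=False).
+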
logger.info(f"Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_acc:.4f}") + return { + "valid_loss": np.array(loss_list).sum() / float(sample_cnt), + "valid_acc": valid_acc / float(sample_cnt), + } + + def save_training_state(self, log_dict: dict) -> None: + """๋ชจ๋ธ, optimizer์™€ ๊ธฐํƒ€ ์ •๋ณด๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค""" + self.model.checkpoint(self.best_val_ckpt_path) + training_state = { + "optimizer_state": deepcopy(self.optimizer.state_dict()), + "scheduler_state": deepcopy(self.scheduler.state_dict()), + } + training_state.update(log_dict) + torch.save(training_state, self.best_val_optim_path) + logger.debug(f"saved optimizer/scheduler state into {self.best_val_optim_path}") + + def load_training_state(self) -> None: + """๋ชจ๋ธ, optimizer์™€ ๊ธฐํƒ€ ์ •๋ณด๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค""" + self.model.load(self.best_val_ckpt_path) + training_state = torch.load(self.best_val_optim_path) + logger.debug(f"loaded optimizer/scheduler state from {self.best_val_optim_path}") + self.optimizer.load_state_dict(training_state["optimizer_state"]) + self.scheduler.load_state_dict(training_state["scheduler_state"]) + self.start_ep = training_state["epoch"] + self.start_step = training_state["step"] + logger.debug(f"resume training from epoch {self.start_ep} / step {self.start_step}") + + +# ๋ชจ๋ธ ์กด์žฌ ์—ฌ๋ถ€ ํ™•์ธ ํ•จ์ˆ˜ +def check_if_model_exists(model_path: str): + """๋ชจ๋ธ ์ฒดํฌํฌ์ธํŠธ๊ฐ€ ์กด์žฌํ•˜๋Š”์ง€ ํ™•์ธํ•˜๋Š” ํ•จ์ˆ˜""" + return os.path.exists(model_path) + + +# ๋ฉ”์ธ ์‹คํ–‰ +if __name__ == "__main__": + # ๋ชจ๋ธ ๊ฒฝ๋กœ ์„ค์ • + model_path = "./output/my_model.pt" + + # ๋ชจ๋ธ์ด ์—†์œผ๋ฉด ํ•™์Šต ์‹œ์ž‘ + if not check_if_model_exists(model_path): + logger.info(f"๋ชจ๋ธ '{model_path}'์ด ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. ํ•™์Šต์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค.") + + # ํ•™์Šต์„ ์œ„ํ•œ ์ค€๋น„ + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = KobertBiEncoder() + train_dataset = KorQuadDataset("./data/KorQuAD_v1.0_train.json") + valid_dataset = KorQuadDataset("./data/KorQuAD_v1.0_dev.json") + + # Trainer ๊ฐ์ฒด ์ƒ์„ฑ + my_trainer = Trainer( + model=model, + device=device, + train_dataset=train_dataset, + valid_dataset=valid_dataset, + num_epoch=1, # ํ•™์Šต epoch ์ˆ˜ + batch_size=32, # ๋ฐฐ์น˜ ํฌ๊ธฐ + lr=1e-5, + betas=(0.9, 0.99), + num_warmup_steps=100, + num_training_steps=1000, + valid_every=100, + best_val_ckpt_path=model_path, + ) + + # ํ•™์Šต ์ˆ˜ํ–‰ + my_trainer.fit() + eval_dict = my_trainer.evaluate() + logger.info(eval_dict) + + # ๋ชจ๋ธ ์ €์žฅ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ ๋ฐ ์ €์žฅ + os.makedirs("output", exist_ok=True) + torch.save(model.state_dict(), model_path) + logger.info(f"ํ•™์Šต ์™„๋ฃŒ. ๋ชจ๋ธ์ด '{model_path}'์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + else: + logger.info(f"๋ชจ๋ธ '{model_path}'์ด ์ด๋ฏธ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. 
ํ•™์Šต์„ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค.") diff --git a/code/rag/utils.py b/code/rag/utils.py new file mode 100644 index 0000000..084b1cd --- /dev/null +++ b/code/rag/utils.py @@ -0,0 +1,36 @@ +from glob import glob +import math +import typing + +import torch + + +def get_wiki_filepath(data_dir): + return glob(f"{data_dir}/*/wiki_*") + + +def wiki_worker_init(worker_id): + worker_info = torch.utils.data.get_worker_info() + dataset = worker_info.dataset + # print(dataset) + # dataset = + overall_start = dataset.start + overall_end = dataset.end + split_size = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) + worker_id = worker_info.id + # end_idx = min((worker_id+1) * split_size, len(dataset.data)) + dataset.start = overall_start + worker_id * split_size + dataset.end = min(dataset.start + split_size, overall_end) # index error ๋ฐฉ์ง€ + + +def get_passage_file(p_id_list: typing.List[int]) -> str: + """passage id๋ฅผ ๋ฐ›์•„์„œ ํ•ด๋‹น๋˜๋Š” ํŒŒ์ผ ์ด๋ฆ„์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + target_file = None + p_id_max = max(p_id_list) + p_id_min = min(p_id_list) + for f in glob("processed_passages/*.p"): + s, e = f.split("/")[1].split(".")[0].split("-") + s, e = int(s), int(e) + if p_id_min >= s and p_id_max <= e: + target_file = f + return target_file diff --git a/code/split.py b/code/split.py new file mode 100644 index 0000000..361573a --- /dev/null +++ b/code/split.py @@ -0,0 +1,85 @@ +from ast import literal_eval + +from loguru import logger +import pandas as pd + + +def load_data(file_path): + data = pd.read_csv(file_path) + records = [] + for _, row in data.iterrows(): + problems = literal_eval(row["problems"]) + record = { + "id": row["id"], + "paragraph": row["paragraph"], + "question": problems["question"], + "choices": problems["choices"], + "answer": problems.get("answer", None), + } + records.append(record) + logger.debug(records[0]) # ์ฒซ ๋ฒˆ์งธ ๋ ˆ์ฝ”๋“œ ์ถœ๋ ฅ (๋””๋ฒ„๊น…์šฉ) + return data, records + + +def classify_questions(records): + social_keywords = ["ใ„ฑ.", "ใ‰ ", "์œ„์˜", "์œ„ ๊ธ€" "๋‹จ๋ฝ", "๋ณธ๋ฌธ", "๋ฐ‘์ค„ ์นœ", "(๊ฐ€)", "๋‹ค์Œ", "์‹œ๊ธฐ"] + classifications = [] + + for record in records: + question = record["question"] + paragraph = record["paragraph"] + + # ์‚ฌํšŒ ์˜์—ญ ํŒ๋‹จ + contains_social_keywords = any(keyword in question for keyword in social_keywords) + + # ๊ฐ ์„ ํƒ์ง€๊ฐ€ ๋ณธ๋ฌธ์— ํฌํ•จ๋˜์–ด ์žˆ๋Š”์ง€ ํ™•์ธ + choices_found_in_paragraph = {choice: choice in paragraph for choice in record["choices"]} + + # ์ •๋‹ต์ด ํฌํ•จ๋œ ์„ ํƒ์ง€ ์ฐพ๊ธฐ + answer_index = record["answer"] + answer_found_in_paragraph = False + + if answer_index is not None and 0 <= answer_index < len(record["choices"]): + answer = record["choices"][answer_index] # ์ •๋‹ต ์„ ํƒ์ง€ + answer_found_in_paragraph = choices_found_in_paragraph.get(answer, False) + + # if contains_social_keywords and not answer_found_in_paragraph: + # classification = '์‚ฌํšŒ' + if contains_social_keywords: + classification = "์‚ฌํšŒ" + elif not contains_social_keywords and answer_found_in_paragraph: + classification = "๊ตญ์–ด" + else: + classification = "๋ถˆํ™•์‹ค" # ๋‘ ์กฐ๊ฑด ๋ชจ๋‘ ํ•ด๋‹นํ•˜์ง€ ์•Š๊ฑฐ๋‚˜ ๋ชจ๋‘ ํ•ด๋‹นํ•˜๋Š” ๊ฒฝ์šฐ + + classifications.append( + { + "id": record["id"], + "classification": classification, + "contains_social_keywords": contains_social_keywords, + "answer_found_in_paragraph": answer_found_in_paragraph, + "choices_found_in_paragraph": choices_found_in_paragraph, # ์„ ํƒ์ง€ ํฌํ•จ ์—ฌ๋ถ€ ์ถ”๊ฐ€ + } + ) + + return classifications + + +def main(): + file_path = 
"../data/train.csv" # ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ • + data, records = load_data(file_path) + + classifications = classify_questions(records) + + # ๊ฒฐ๊ณผ๋ฅผ ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„์œผ๋กœ ๋ณ€ํ™˜ + result_df = pd.DataFrame(classifications) + + # ๊ฒฐ๊ณผ๋ฅผ CSV ํŒŒ์ผ๋กœ ์ €์žฅ + output_file_path = "../data/classification_results.csv" # ์ถœ๋ ฅ ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ • + result_df.to_csv(output_file_path, index=False, encoding="utf-8-sig") # CSV๋กœ ์ €์žฅ + + logger.debug(f"๊ฒฐ๊ณผ๊ฐ€ {output_file_path}์— ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + + +if __name__ == "__main__": + main() diff --git a/code/trainer.py b/code/trainer.py new file mode 100644 index 0000000..5304a5c --- /dev/null +++ b/code/trainer.py @@ -0,0 +1,115 @@ +import evaluate +import numpy as np +from peft import LoraConfig +import torch +from transformers import EarlyStoppingCallback +from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer + + +class CustomTrainer: + def __init__(self, training_config, model, tokenizer, train_dataset, eval_dataset): + self.model = model + self.tokenizer = tokenizer + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.training_config = training_config + self.acc_metric = evaluate.load("accuracy") + self.int_output_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4} + + def train(self): + trainer = self._setup_trainer() + trainer.train() + return trainer.model + + def _setup_trainer(self): + # ๋ฐ์ดํ„ฐ ์ฝœ๋ ˆ์ดํ„ฐ ์„ค์ • + data_collator = DataCollatorForCompletionOnlyLM( + response_template=self.training_config["response_template"], + tokenizer=self.tokenizer, + ) + + # LoRA ์„ค์ • + peft_config = LoraConfig( + r=self.training_config["lora"]["r"], + lora_alpha=self.training_config["lora"]["lora_alpha"], + lora_dropout=self.training_config["lora"]["lora_dropout"], + target_modules=self.training_config["lora"]["target_modules"], + bias=self.training_config["lora"]["bias"], + task_type=self.training_config["lora"]["task_type"], + ) + + # SFT ์„ค์ • + sft_config = SFTConfig( + do_train=self.training_config["params"]["do_train"], + do_eval=self.training_config["params"]["do_eval"], + lr_scheduler_type=self.training_config["params"]["lr_scheduler_type"], + max_seq_length=self.training_config["params"]["max_seq_length"], + per_device_train_batch_size=self.training_config["params"]["per_device_train_batch_size"], + per_device_eval_batch_size=self.training_config["params"]["per_device_eval_batch_size"], + gradient_accumulation_steps=self.training_config["params"]["gradient_accumulation_steps"], + gradient_checkpointing=self.training_config["params"]["gradient_checkpointing"], + max_grad_norm=self.training_config["params"]["max_grad_norm"], + num_train_epochs=self.training_config["params"]["num_train_epochs"], + learning_rate=self.training_config["params"]["learning_rate"], + weight_decay=self.training_config["params"]["weight_decay"], + optim=self.training_config["params"]["optim"], + logging_strategy=self.training_config["params"]["logging_strategy"], + save_strategy=self.training_config["params"]["save_strategy"], + eval_strategy=self.training_config["params"]["eval_strategy"], + logging_steps=self.training_config["params"]["logging_steps"], + save_steps=self.training_config["params"]["save_steps"], + eval_steps=self.training_config["params"]["eval_steps"], + save_total_limit=self.training_config["params"]["save_total_limit"], + save_only_model=self.training_config["params"]["save_only_model"], + load_best_model_at_end=self.training_config["params"]["load_best_model_at_end"], + 
report_to=self.training_config["params"]["report_to"], + run_name=self.training_config["params"]["run_name"], + output_dir=self.training_config["params"]["output_dir"], + overwrite_output_dir=self.training_config["params"]["overwrite_output_dir"], + metric_for_best_model=self.training_config["params"]["metric_for_best_model"], + ) + + return SFTTrainer( + model=self.model, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + data_collator=data_collator, + tokenizer=self.tokenizer, + peft_config=peft_config, + args=sft_config, + compute_metrics=self._compute_metrics, + preprocess_logits_for_metrics=self._preprocess_logits_for_metrics, + callbacks=[ + EarlyStoppingCallback( + early_stopping_patience=self.training_config["params"]["early_stop_patience"], + early_stopping_threshold=self.training_config["params"]["early_stop_threshold"], + ) + ], + ) + + # ๋ชจ๋ธ์˜ logits๋ฅผ ์กฐ์ •ํ•˜์—ฌ ์ •๋‹ต ํ† ํฐ ๋ถ€๋ถ„๋งŒ ์ถœ๋ ฅํ•˜๋„๋ก ์„ค์ • + def _preprocess_logits_for_metrics(self, logits, labels): + logits = logits if not isinstance(logits, tuple) else logits[0] + logit_idx = [ + self.tokenizer.vocab["1"], + self.tokenizer.vocab["2"], + self.tokenizer.vocab["3"], + self.tokenizer.vocab["4"], + self.tokenizer.vocab["5"], + ] + return logits[:, -2, logit_idx] # -2: answer token, -1: eos token + + # metric ๊ณ„์‚ฐ ํ•จ์ˆ˜ + def _compute_metrics(self, evaluation_result): + logits, labels = evaluation_result + # ํ† ํฐํ™”๋œ ๋ ˆ์ด๋ธ” ๋””์ฝ”๋”ฉ + labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id) + labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + labels = [x.split("")[0].strip() for x in labels] + labels = [self.int_output_map[x] for x in labels] + + # softmax ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ logits ๋ณ€ํ™˜ + probs = torch.nn.functional.softmax(torch.tensor(logits), dim=-1) + predictions = np.argmax(probs, axis=-1) + + return self.acc_metric.compute(predictions=predictions, references=labels) diff --git a/code/utils/__init__.py b/code/utils/__init__.py new file mode 100644 index 0000000..04618c3 --- /dev/null +++ b/code/utils/__init__.py @@ -0,0 +1,13 @@ +""" +ํ”„๋กœ์ ํŠธ ์ „๋ฐ˜์— ์‚ฌ์šฉํ•˜๋Š” ์œ ํ‹ธ๋ฆฌํ‹ฐ ๋ชจ๋“ˆ์ž…๋‹ˆ๋‹ค. 
+ +## ์ฃผ์š” ๊ธฐ๋Šฅ +- hf_manager.py: ํ—ˆ๊น…ํŽ˜์ด์Šค์— ๋ชจ๋ธ/๋ฐ์ดํ„ฐ ์—…๋กœ๋“œ +- gdrive_manager.py: config & output์„ ๊ตฌ๊ธ€ ๋“œ๋ผ์ด๋ธŒ๋กœ ์ž๋™ ์—…๋กœ๋“œ +- common.py: ์ธ์ž ๋ฐ ๋กœ๊น… ์„ค์ •์„ ์œ„ํ•œ ํ•จ์ˆ˜ ๋ชจ์Œ + +""" + +from .common import create_experiment_filename, load_config, load_env_file, log_config, set_logger, set_seed, timer +from .gdrive_manager import GoogleDriveManager +from .hf_manager import HuggingFaceHubManager diff --git a/code/utils/common.py b/code/utils/common.py new file mode 100644 index 0000000..ac5aa5b --- /dev/null +++ b/code/utils/common.py @@ -0,0 +1,106 @@ +import argparse +from contextlib import contextmanager +from datetime import datetime +import os +import random +import time + +from dotenv import load_dotenv +from loguru import logger +import numpy as np +import torch +import yaml +from zoneinfo import ZoneInfo + + +# ์ฝ”๋“œ ์ „์—ญ์—์„œ ์ฒซ ์‹คํ–‰ ์‹œ์ ์˜ ํƒ€์ž„์Šคํƒฌํ”„๋ฅผ ๋™์ผํ•˜๊ฒŒ ์‚ฌ์šฉ +CURRENT_TIME = None + + +def set_seed(seed=42): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if use multi-GPU + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(seed) + random.seed(seed) + + +def load_config(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="config.yaml") + args = parser.parse_args() + + with open(os.path.join("../config", args.config), encoding="utf-8") as f: + config = yaml.safe_load(f) + return config + + +def load_env_file(filepath="../config/.env"): + try: + # .env ํŒŒ์ผ ๋กœ๋“œ ์‹œ๋„ + if load_dotenv(filepath): + logger.debug(f".env ํŒŒ์ผ์„ ์„ฑ๊ณต์ ์œผ๋กœ ๋กœ๋“œํ–ˆ์Šต๋‹ˆ๋‹ค: {filepath}") + else: + raise FileNotFoundError # ํŒŒ์ผ์ด ์—†์œผ๋ฉด ์˜ˆ์™ธ ๋ฐœ์ƒ + except FileNotFoundError: + logger.debug(f"๊ฒฝ๊ณ : ์ง€์ •๋œ .env ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {filepath}") + except Exception as e: + logger.debug(f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: .env ํŒŒ์ผ ๋กœ๋“œ ์ค‘ ์˜ˆ์™ธ๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}") + + +def set_logger(log_file="../log/file.log", log_level="DEBUG"): + # ๋กœ๊ฑฐ ์„ค์ • + logger.add( + log_file, + format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", + level=log_level, + rotation="12:00", # ๋งค์ผ 12์‹œ์— ์ƒˆ๋กœ์šด ๋กœ๊ทธ ํŒŒ์ผ ์ƒ์„ฑ + retention="7 days", # 7์ผ ํ›„ ๋กœ๊ทธ ์ œ๊ฑฐ + ) + + +# config ํ™•์ธ +def log_config(config, depth=0): + if depth == 0: + print("*" * 40) + for k, v in config.items(): + prefix = ["\t" * depth, k, ":"] + + if isinstance(v, dict): + print(*prefix) + log_config(v, depth + 1) + else: + prefix.append(v) + print(*prefix) + if depth == 0: + print("*" * 40) + + +def get_current_time(): + global CURRENT_TIME + if CURRENT_TIME is None: + CURRENT_TIME = datetime.now(ZoneInfo("Asia/Seoul")).strftime("%m%d%H%M") + return CURRENT_TIME + + +def create_experiment_filename(config): + if config is None: + config = load_config() + username = config["exp"]["username"] + base_model = config["model"]["base_model"].replace("/", "_") + train_path = config["data"]["train_path"] + train_name = os.path.splitext(os.path.basename(train_path))[0] + num_train_epochs = config["training"]["params"]["num_train_epochs"] + learning_rate = config["training"]["params"]["learning_rate"] + current_time = get_current_time() + + return f"{username}_{base_model}_{train_name}_{num_train_epochs}_{learning_rate}_{current_time}" + + +@contextmanager +def timer(name): + t0 = time.time() + yield + logger.debug(f"[{name}] done in {time.time() - t0:.3f} s") diff --git 
a/code/utils/gdrive_manager.py b/code/utils/gdrive_manager.py new file mode 100755 index 0000000..a882613 --- /dev/null +++ b/code/utils/gdrive_manager.py @@ -0,0 +1,196 @@ +import io +import json +import os.path + +from dotenv import load_dotenv +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from google_auth_oauthlib.flow import InstalledAppFlow +from googleapiclient.discovery import build +from googleapiclient.http import MediaFileUpload, MediaIoBaseUpload +from loguru import logger +import pandas as pd + + +SCOPES = [ + "https://www.googleapis.com/auth/drive.file", + "https://www.googleapis.com/auth/drive", +] + + +class GoogleDriveManager: + def __init__(self): + config_folder = os.path.join(os.path.dirname(__file__), "..", "..", "config") + load_dotenv(os.path.join(config_folder, ".env")) + self.config_folder = config_folder + self.root_folder_id = os.getenv("GDRIVE_FOLDER_ID") + self.credentials = os.getenv("GDRIVE_CREDENTIALS") + self.token = os.getenv("GDRIVE_TOKEN") + self.is_create_token = os.getenv("GDRIVE_CREATE_TOKEN") + + # ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ๊ฒ€์ฆ ์ถ”๊ฐ€ + if not all([self.root_folder_id, self.credentials, self.token]): + logger.error(f"ํ•„์ˆ˜ ํ™˜๊ฒฝ ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. {[self.root_folder_id, self.credentials, self.token]}") + raise ValueError("ํ•„์ˆ˜ ํ™˜๊ฒฝ ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.") + + self.service = self.get_drive_service() + + def get_drive_service(self): + creds = None + if os.path.exists(self.token): + creds = Credentials.from_authorized_user_file(self.token, SCOPES) + + if not creds or not creds.valid: + if creds and creds.expired and creds.refresh_token: + creds.refresh(Request()) + elif self.is_create_token == "true": + flow = InstalledAppFlow.from_client_secrets_file(self.credentials, scopes=SCOPES) + creds = flow.run_local_server(port=0, open_browser=False) + # ๋ฆฌํ”„๋ ˆ์‹œ ํ† ํฐ ์ €์žฅ + with open(self.token, "w") as token_file: + token_data = { + "refresh_token": creds.refresh_token, + "token": creds.token, + "token_uri": creds.token_uri, + "client_id": creds.client_id, + "client_secret": creds.client_secret, + "scopes": creds.scopes, + } + json.dump(token_data, token_file) + logger.info("๊ตฌ๊ธ€ ๋“œ๋ผ์ด๋ธŒ ํ† ํฐ์ด ๊ฐฑ์‹ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + else: + logger.error("๊ตฌ๊ธ€ ๋“œ๋ผ์ด๋ธŒ ์—…๋กœ๋“œ ์‹คํŒจ. 
ํ† ํฐ์„ ๊ฐฑ์‹ ์„ ์š”์ฒญํ•˜์„ธ์š”.") + with open(self.token, "w") as token: + token.write(creds.to_json()) + return build("drive", "v3", credentials=creds) + + def find_folder_id_by_name(self, folder_name, parent_folder_id=None): + """ํด๋”๋ช…์œผ๋กœ ํด๋” ID ์ฐพ๊ธฐ""" + if not parent_folder_id: + parent_folder_id = self.root_folder_id + + # ํŠน์ • ํด๋”๋ช…๊ณผ ์ •ํ™•ํžˆ ์ผ์น˜ํ•˜๋Š” ํด๋” ๊ฒ€์ƒ‰ ์ฟผ๋ฆฌ + query = f"name='{folder_name}' and " + query += f"'{parent_folder_id}' in parents and " + query += "mimeType='application/vnd.google-apps.folder' and " + query += "trashed=false" + + try: + results = ( + self.service.files() + .list( + q=query, + spaces="drive", + fields="files(id, name)", + pageSize=1, # ์ฒซ ๋ฒˆ์งธ ์ผ์น˜ํ•˜๋Š” ํด๋”๋งŒ ํ•„์š” + ) + .execute() + ) + + files = results.get("files", []) + + if not files: + logger.info(f"ํด๋”๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {folder_name}") + return None + + return files[0]["id"] + + except Exception as e: + logger.info(f"ํด๋” ๊ฒ€์ƒ‰ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}") + return None + + def list_folder_files(self, folder_id=None): + """ํด๋” ๋‚ด ํŒŒ์ผ ๋ชฉ๋ก ์กฐํšŒ""" + if not folder_id: + folder_id = self.root_folder_id + query = f"'{folder_id}' in parents and trashed=false" + + try: + results = ( + self.service.files() + .list( + q=query, + pageSize=100, + fields="nextPageToken, files(id, name, mimeType, modifiedTime, size)", + ) + .execute() + ) + + return results.get("files", []) + except Exception as e: + logger.info(f"Error listing files: {str(e)}") + return [] + + def upload_yaml_file(self, file_path, filename, folder_id=None): + """YAML ํŒŒ์ผ ๊ฒฝ๋กœ๋ฅผ ๋ฐ›์•„์„œ ์—…๋กœ๋“œ""" + try: + # ํŒŒ์ผ ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ์„ค์ • + file_metadata = {"name": filename, "mimeType": "application/x-yaml"} + if folder_id: + file_metadata["parents"] = [folder_id] + + # ๋ฏธ๋””์–ด ๊ฐ์ฒด ์ƒ์„ฑ + media = MediaFileUpload(file_path, mimetype="application/x-yaml", resumable=True) + + # ํŒŒ์ผ ์—…๋กœ๋“œ + file = self.service.files().create(body=file_metadata, media_body=media, fields="id, name").execute() + + logger.debug(f"Successfully uploaded {filename} to Google Drive") + return file + + except FileNotFoundError: + logger.error(f"File not found: {file_path}") + return None + except Exception as e: + logger.error(f"Error uploading YAML file: {str(e)}") + return None + + def upload_dataframe(self, dataframe, filename, folder_id=None): + """Pandas DataFrame ์ง์ ‘ ์—…๋กœ๋“œ""" + try: + # DataFrame์„ CSV ์ŠคํŠธ๋ฆผ์œผ๋กœ ๋ณ€ํ™˜ + buffer = io.StringIO() + dataframe.to_csv(buffer, index=False) + file_stream = io.BytesIO(buffer.getvalue().encode("utf-8")) + + # ํŒŒ์ผ ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ์„ค์ • + file_metadata = {"name": filename, "mimeType": "text/csv"} + if folder_id: + file_metadata["parents"] = [folder_id] + + # ๋ฏธ๋””์–ด ๊ฐ์ฒด ์ƒ์„ฑ + media = MediaIoBaseUpload(file_stream, mimetype="text/csv", resumable=True) + + # ํŒŒ์ผ ์—…๋กœ๋“œ + file = self.service.files().create(body=file_metadata, media_body=media, fields="id, name").execute() + return file + + except Exception as e: + logger.error(f"Error uploading DataFrame: {str(e)}") + return None + + def upload_exp(self, user_name, output_path, config_path=None): + df = pd.read_csv(output_path) + df_basename = os.path.basename(output_path) + + if config_path is None: + config_path = os.path.join(self.config_folder, "config.yaml") + config_basename = df_basename.replace("output.csv", "config.yaml") + + # ์‹คํ—˜์ž๋ช…์œผ๋กœ ํด๋”๋ช… ์ฐพ๊ธฐ + folder_id = self.find_folder_id_by_name(user_name) + _ = 
self.upload_dataframe(df, df_basename, folder_id) + _ = self.upload_yaml_file(config_path, config_basename, folder_id) + + gdrive_url = os.path.join("https://drive.google.com/drive/folders", folder_id) + logger.info(f"๊ตฌ๊ธ€ ๋“œ๋ผ์ด๋ธŒ์— ์—…๋กœ๋“œ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค: {gdrive_url}") + + +if __name__ == "__main__": + os.chdir("..") + load_dotenv("../config/.env") + drive_manager = GoogleDriveManager() + # ํŒŒ์ผ ๋ชฉ๋ก ์กฐํšŒ + files = drive_manager.list_folder_files() + for file in files: + logger.info(f"Name: {file['name']}, ID: {file['id']}") diff --git a/code/utils/hf_manager.py b/code/utils/hf_manager.py new file mode 100755 index 0000000..976798f --- /dev/null +++ b/code/utils/hf_manager.py @@ -0,0 +1,81 @@ +import os + +from datasets import load_dataset +from dotenv import load_dotenv +from huggingface_hub import HfApi +from loguru import logger +from peft import AutoPeftModelForCausalLM +from transformers import AutoTokenizer + + +class HuggingFaceHubManager: + def __init__(self): + load_dotenv(os.path.join(os.path.dirname(__file__), "..", "..", "config", ".env")) + self.token = os.getenv("HF_TOKEN") + self.organization = os.getenv("HF_TEAM_NAME") + self.project_name = os.getenv("HF_PROJECT_NAME") + + # ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ๊ฒ€์ฆ ์ถ”๊ฐ€ + if not all([self.token, self.organization, self.project_name]): + raise ValueError("ํ•„์ˆ˜ ํ™˜๊ฒฝ ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.") + + def upload_model(self, model_name, username, checkpoint_path): + repo_id = f"{model_name}-{username}" + try: + model = AutoPeftModelForCausalLM.from_pretrained( + checkpoint_path, + trust_remote_code=True, + device_map="auto", + ) + tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True) + + model.push_to_hub(repo_id=repo_id, organization=self.organization, use_auth_token=self.token) + tokenizer.push_to_hub(repo_id=repo_id, organization=self.organization, use_auth_token=self.token) + logger.debug(f"your model pushed successfully in {repo_id}, hugging face") + except Exception as e: + logger.debug(f"An error occurred while uploading to Hugging Face: {e}") + + def upload_dataset(self, file_name, private=True): + """ + ํด๋” ๋‚ด ๋ฐ์ดํ„ฐ ํŒŒ์ผ๋“ค์„ Hugging Face Hub์— ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ์—…๋กœ๋“œํ•˜๋Š” ํ•จ์ˆ˜. + + Parameters: + - file_name (str): ์—…๋กœ๋“œํ•  ๋กœ์ปฌ ๋ฐ์ดํ„ฐ ํŒŒ์ผ ์ด๋ฆ„ + - token (str): Hugging Face ์•ก์„ธ์Šค ํ† ํฐ. ์“ฐ๊ธฐ ๊ถŒํ•œ ํ•„์š” + - private (bool): True๋ฉด ๋น„๊ณต๊ฐœ, False๋ฉด ๊ณต๊ฐœ ์„ค์ • + + Returns: + - None + """ + api = HfApi() + repo_id = f"{self.organization}/{self.project_name}-{file_name}" + + # ๋ฆฌํฌ์ง€ํ† ๋ฆฌ ์กด์žฌ ์—ฌ๋ถ€ ํ™•์ธ + try: + api.repo_info(repo_id, repo_type="dataset", token=self.token) + logger.debug(f"'{repo_id}' ๋ฆฌํฌ์ง€ํ† ๋ฆฌ๊ฐ€ ์ด๋ฏธ ์กด์žฌํ•ฉ๋‹ˆ๋‹ค. ๊ธฐ์กด ๋ฆฌํฌ์ง€ํ† ๋ฆฌ์— ๋ฐ์ดํ„ฐ์…‹์„ ์—…๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.") + except Exception: + # ๋ฆฌํฌ์ง€ํ† ๋ฆฌ๊ฐ€ ์—†์œผ๋ฉด ์ƒ์„ฑ + logger.debug(f"'{repo_id}' ๋ฆฌํฌ์ง€ํ† ๋ฆฌ๊ฐ€ ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. 
์ƒˆ๋กœ ์ƒ์„ฑํ•œ ํ›„ ์—…๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.") + api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, token=self.token) + + # ํŒŒ์ผ ๊ฒฝ๋กœ ์„ค์ • + file_path = os.path.join("..", "data", f"{file_name}.csv") + if not os.path.exists(file_path): + logger.debug(f"ํŒŒ์ผ '{file_path}'์ด ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.") + return + + # ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ ๋ฐ ์—…๋กœ๋“œ + dataset = load_dataset("csv", data_files={"train": file_path}) + dataset.push_to_hub(repo_id, token=self.token) + logger.debug(f"๋ฐ์ดํ„ฐ์…‹์ด '{repo_id}'์— ์—…๋กœ๋“œ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") + + +if __name__ == "__main__": + os.chdir("..") + load_dotenv("../config/.env") + logger.debug(f'{os.getenv("UPLOAD_MODEL_NAME")}, {os.getenv("USERNAME")}, {os.getenv("CHECKPOINT_PATH")}') + + hf_manager = HuggingFaceHubManager() + hf_manager.upload_model(os.getenv("UPLOAD_MODEL_NAME"), os.getenv("USERNAME"), os.getenv("CHECKPOINT_PATH")) + # hf_manager.upload_dataset(args.dataname, private=True) diff --git a/config/elastic_setting.json b/config/elastic_setting.json new file mode 100644 index 0000000..4e49153 --- /dev/null +++ b/config/elastic_setting.json @@ -0,0 +1,33 @@ +{ + "settings": { + "analysis": { + "filter": { + "my_shingle": { + "type": "shingle" + } + }, + "analyzer": { + "my_analyzer": { + "type": "custom", + "tokenizer": "nori_tokenizer", + "decompound_mode": "mixed", + "filter": ["my_shingle"] + } + }, + "similarity": { + "my_similarity": { + "type": "BM25" + } + } + } + }, + + "mappings": { + "properties": { + "document_text": { + "type": "text", + "analyzer": "my_analyzer" + } + } + } +} diff --git a/config/sample/config.yaml b/config/sample/config.yaml new file mode 100755 index 0000000..df99b8a --- /dev/null +++ b/config/sample/config.yaml @@ -0,0 +1,95 @@ +data: + train_path: "../data/train.csv" + test_path: "../data/test.csv" + processed_train_path: "../data/train_500_60to1_es.csv" # ๋ฏธ๋ฆฌ ์ „์ฒ˜๋ฆฌํ•œ ๋ฐ์ดํ„ฐ ์‚ฌ์šฉ: ๋น„์›Œ๋‘๋ฉด ๋™์ž‘ํ•˜์ง€ ์•Š์Œ + processed_test_path: "../data/test_500_60to1_es.csv" # ๋ฏธ๋ฆฌ ์ „์ฒ˜๋ฆฌํ•œ ๋ฐ์ดํ„ฐ ์‚ฌ์šฉ: ๋น„์›Œ๋‘๋ฉด ๋™์ž‘ํ•˜์ง€ ์•Š์Œ + max_seq_length: 2048 + test_size: 0.1 + retriever: + retriever_type: "Elasticsearch" # Elasticsearch + query_type: "p" # retrieve ์ฟผ๋ฆฌ ํƒ€์ž…: pqc, pq, pc, p + query_max_length: 500 # retrieve ๋Œ€์ƒ์ด ๋  ์ฟผ๋ฆฌ์˜ ์ตœ๋Œ€ ๊ธธ์ด: 250-500 ๊ถŒ์žฅ + result_max_length: 1500 # retrieve ๊ฒฐ๊ณผ ๋ฌธ์„œ์˜ ์ตœ๋Œ€ ๊ธธ์ด: 1500-2000 ๊ถŒ์žฅ + top_k: 60 # 60~80 + rerank_k: 1 # 0 ์ดํ•˜๋Š” reranker ๋™์ž‘ํ•˜์ง€ ์•Š์Œ + threshold: 0.2 # 0.2 ~ 0.5 + index_name: "two-wiki-index" # wiki-index, two-wiki-index, aihub-news-index + prompt: + start: "์ง€๋ฌธ:\n {paragraph}\n\n์งˆ๋ฌธ:\n {question}\n\n์„ ํƒ์ง€:\n {choices}\n\n" + start_with_plus: "์ง€๋ฌธ:\n {paragraph}\n\n์งˆ๋ฌธ:\n {question}\n\n<๋ณด๊ธฐ>:\n {question_plus}\n\n์„ ํƒ์ง€:\n {choices}\n\n" + mid: "" + mid_with_document: "ํžŒํŠธ:\n {document}\n\n" + end: "1, 2, 3, 4, 5 ์ค‘์— ํ•˜๋‚˜๋ฅผ ์ •๋‹ต์œผ๋กœ ๊ณ ๋ฅด์„ธ์š”.\n์ •๋‹ต:" + end_gen_cot: "1, 2, 3, 4, 5 ์ค‘์— ํ•˜๋‚˜๋ฅผ ์ •๋‹ต์œผ๋กœ ๊ณ ๋ฅด๊ธฐ ์œ„ํ•œ ๊ทผ๊ฑฐ๋ฅผ ์ฐจ๊ทผ์ฐจ๊ทผ ์ƒ๊ฐํ•ด๋ณด์„ธ์š”.\n๊ทผ๊ฑฐ:" + end_with_cot: "1, 2, 3, 4, 5 ์ค‘์— ํ•˜๋‚˜๋ฅผ ์ •๋‹ต์œผ๋กœ ๊ณ ๋ฅด์„ธ์š”.\n{cot}\n์ •๋‹ต:" + +model: + base_model: "beomi/gemma-ko-2b" + model: + torch_dtype: "float16" + low_cpu_mem_usage: true + use_cache: false # gradient_checkpointing์ด true๋ฉด false์—ฌ์•ผํ•จ + quantization: "" # BitsAndBytes, auto + bits: 8 # 8 or 4 + use_double_quant: false + tokenizer: + padding_side: "right" + chat_template: "{% if messages[0]['role'] == 'system' %}{% 
set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ 'user\n' + content + '\nmodel\n' }}{% elif message['role'] == 'assistant' %}{{ content + '\n' }}{% endif %}{% endfor %}" + +training: + response_template: "model" + lora: + r: 6 + lora_alpha: 8 + lora_dropout: 0.05 + target_modules: ["q_proj", "k_proj"] + bias: "none" + task_type: "CAUSAL_LM" + + params: + do_train: true + do_eval: true + lr_scheduler_type: "cosine" + max_seq_length: 2048 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + gradient_accumulation_steps: 1 + gradient_checkpointing: true + max_grad_norm: 0.3 + num_train_epochs: 3 + learning_rate: 2.0e-05 + weight_decay: 0.01 + optim: "adamw_torch" # ์–‘์žํ™”: adamw_bnb_8bit + logging_strategy: "steps" + save_strategy: "steps" + eval_strategy: "steps" + logging_steps: 300 + save_steps: 600 + eval_steps: 300 + save_total_limit: 4 + save_only_model: true + load_best_model_at_end: true # early_stop์„ ์œ„ํ•ด ํ•„์š” + report_to: "wandb" + run_name: "../outputs" # wandb ์„ธํŒ…์ด ์กด์žฌํ•œ๋‹ค๋ฉด ๋™์ ์œผ๋กœ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค. + output_dir: "../outputs" + overwrite_output_dir: true + metric_for_best_model: "accuracy" # early_stop ๊ธฐ์ค€ + early_stop_patience: 2 + early_stop_threshold: 0 + + +inference: + do_test: true + output_path: "../outputs/" + +log: + file: "../log/file.log" + level: "INFO" + +wandb: + project: generation_for_nlp + entity: hidong1015-nlp04 + +exp: + # ์‹คํ—˜์ž [sujin, seongmin, sungjae, gayeon, yeseo, minseo] + username: fubao diff --git a/config/sample/env-sample.txt b/config/sample/env-sample.txt new file mode 100644 index 0000000..dc8689d --- /dev/null +++ b/config/sample/env-sample.txt @@ -0,0 +1,17 @@ +# .env๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ์‚ฌ์šฉ +HF_TOKEN="" +HF_TEAM_NAME = "paper-company" +HF_PROJECT_NAME = "KSAT" + +GDRIVE_TOKEN="../config/token.json" +GDRIVE_CREDENTIALS="../config/credentials.json" +GDRIVE_FOLDER_ID ="" +GDRIVE_CREATE_TOKEN= "false" + +UPLOAD_MODEL_NAME = "1115-fubao-exaone3.0-base-v1" # ๋‚ ์งœ-์ด๋ฆ„-๋ฒ ์ด์Šค๋ชจ๋ธ-์‚ฌ์šฉ๋ฐ์ดํ„ฐ์…‹-๋ฒ„์ „ +USERNAME = "fubao" +CHECKPOINT_PATH="../outputs/checkpoint-9999" + +ELASTICSEARCH_URL="http://localhost:9200" +ELASTICSEARCH_ID="" +ELASTICSEARCH_PW="" diff --git a/data_aug/add_CoT.py b/data_aug/add_CoT.py new file mode 100644 index 0000000..417dd85 --- /dev/null +++ b/data_aug/add_CoT.py @@ -0,0 +1,57 @@ +from ast import literal_eval + +import dspy +import pandas as pd +from tqdm import tqdm + + +def process_csv(file_path, output_path, lm_api_key): + lm = dspy.LM("openai/gpt-4o", api_key=lm_api_key) + dspy.configure(lm=lm) + + df = pd.read_csv(file_path) + + records = [] + for _, row in df.iterrows(): + problems = literal_eval(row["problems"]) + record = { + "id": row["id"], + "paragraph": row["paragraph"], + "question": problems["question"], + "choices": problems["choices"], + "answer": problems.get("answer", None), + "question_plus": problems.get("question_plus", None), + } + records.append(record) + + df = pd.DataFrame(records) + + df["steps"] = None + data_list = [] + + for index, row in tqdm(df.iterrows(), total=len(df)): + input_data = { + "paragraph": row["paragraph"], + "question": row["question"], + "choices": row["choices"], + "answer": row["answer"], + } + classify = dspy.ChainOfThought("paragraph: str, question: str, choices: list, answer: str -> steps: list", n=1) + + input_data["question"] = 
f"{row['question']} ๋‹จ๊ณ„๋ณ„ ์„ค๋ช…(CoT)์„ ์‚ฌ์šฉํ•˜์—ฌ ์˜ฌ๋ฐ”๋ฅธ ๋‹ต์„ ๋„์ถœํ•˜์„ธ์š”." + response = classify(**input_data) + print("response.completions", response.completions) + data_list.append(response.completions) + + df["steps"] = data_list + + df.to_csv(output_path, index=False, encoding="utf-8-sig") + print(f"Updated CSV file saved to: {output_path}") + + +if __name__ == "__main__": + input_file_path = "" + output_file_path = "" + api_key = "" + + process_csv(input_file_path, output_file_path, api_key) diff --git a/data_aug/aug_philo.py b/data_aug/aug_philo.py new file mode 100644 index 0000000..67a210a --- /dev/null +++ b/data_aug/aug_philo.py @@ -0,0 +1,64 @@ +import os +import warnings + +from langchain.prompts import PromptTemplate +from langchain_openai import ChatOpenAI +import pandas as pd + + +def parse_output(output): + lines = output.split("\n") + data_list = [] + data = {"์ง€๋ฌธ": "", "๋ฌธ์ œ": "", "๋ณด๊ธฐ": "", "์ •๋‹ต": "", "ํ•ด์„ค": ""} + current_key = None + + for line in lines: + line = line.strip() + if line.startswith("์ง€๋ฌธ:"): + if data["์ง€๋ฌธ"]: + data_list.append(data.copy()) + data = {"์ง€๋ฌธ": "", "๋ฌธ์ œ": "", "๋ณด๊ธฐ": "", "์ •๋‹ต": "", "ํ•ด์„ค": ""} + current_key = "์ง€๋ฌธ" + data["์ง€๋ฌธ"] = line.replace("์ง€๋ฌธ:", "").strip() + elif line.startswith("๋ฌธ์ œ:"): + current_key = "๋ฌธ์ œ" + data["๋ฌธ์ œ"] = line.replace("๋ฌธ์ œ:", "").strip() + elif line.startswith("๋ณด๊ธฐ:"): + current_key = "๋ณด๊ธฐ" + data["๋ณด๊ธฐ"] = line.replace("๋ณด๊ธฐ:", "").strip() + elif line.startswith("์ •๋‹ต:"): + current_key = "์ •๋‹ต" + data["์ •๋‹ต"] = line.replace("์ •๋‹ต:", "").strip() + elif line.startswith("ํ•ด์„ค:"): + current_key = "ํ•ด์„ค" + data["ํ•ด์„ค"] = line.replace("ํ•ด์„ค:", "").strip() + elif current_key: + data[current_key] += " " + line + + if data["์ง€๋ฌธ"]: + data_list.append(data) + + return data_list + + +if __name__ == "__main__": + warnings.filterwarnings("ignore") + os.environ["OPENAI_API_KEY"] = "" + + # ์‚ฌ์šฉํ•  LLM ๋ชจ๋ธ ์„ค์ • + llm = ChatOpenAI(model_name="gpt-4o-mini", temperature=0.9) + + # ํ”„๋กฌํ”„ํŠธ ํ…œํ”Œ๋ฆฟ ์„ค์ • + prompt_template = """""" + + prompt = PromptTemplate(input_variables=[], template=prompt_template) + + response = llm.invoke(prompt.format()) + output = response.content + + parsed_data_list = parse_output(output) + + df = pd.DataFrame(parsed_data_list) + + csv_filename = "philosophy_questions.csv" + df.to_csv(csv_filename, index=False, encoding="utf-8-sig") diff --git a/data_process/crawling_gichulpass.py b/data_process/crawling_gichulpass.py new file mode 100644 index 0000000..3b846b1 --- /dev/null +++ b/data_process/crawling_gichulpass.py @@ -0,0 +1,159 @@ +import json +import re + +from bs4 import BeautifulSoup +from datasets import load_dataset +import pandas as pd +import requests + + +def answer_symbol_to_int(symbol: str) -> int: + # ์ •๊ทœํ‘œํ˜„์‹์œผ๋กœ ํŠน์ˆ˜๋ฌธ์ž๋งŒ ์ถ”์ถœ + special_chars = re.findall(r"[โ‘ โ‘กโ‘ขโ‘ฃโ‘ค]", symbol) + cleaned_symbol = special_chars[0] if special_chars else symbol + + answer_map = {"โ‘ ": 1, "โ‘ก": 2, "โ‘ข": 3, "โ‘ฃ": 4, "โ‘ค": 5} + return answer_map.get(cleaned_symbol, -1) + + +def extract_question_data(soup, with_table=False): + questions_with_table = [] + questions_without_table = [] + + for item in soup.select("#examList > li"): + # ๋ฌธ์ œ ๋ฒˆํ˜ธ์™€ ์งˆ๋ฌธ + question_element = item.select_one(".pr_problem") + question_text = question_element.get_text(strip=True) + + # ํ…Œ์ด๋ธ” ์กด์žฌ ์—ฌ๋ถ€ ํ™•์ธ + has_table = False + + # ๋ฌธ์ œ์— ํ…Œ์ด๋ธ”์ด ์žˆ๋Š” ๊ฒฝ์šฐ + 
question_table = question_element.find("table") + if question_table: + has_table = True + question_text = {"text": question_text, "table_html": str(question_table)} + + # ๋ฌธ์ œ ์„ค๋ช… (์žˆ๋Š” ๊ฒฝ์šฐ) + example = item.select_one(".exampleCon") + example_text = example.get_text(strip=True) if example else "" + + # ์˜ˆ์‹œ์— ํ…Œ์ด๋ธ”์ด ์žˆ๋Š” ๊ฒฝ์šฐ + if example: + example_table = example.find("table") + if example_table: + has_table = True + example_text = {"text": example_text, "table_html": str(example_table)} + + # '๊ทธ๋ฆผ' ํฌํ•จ๋œ ๋ฌธ์ œ๋Š” ๊ฑด๋„ˆ๋›ฐ๊ธฐ + if isinstance(question_text, str) and "๊ทธ๋ฆผ" in question_text: + continue + if isinstance(example_text, str) and "๊ทธ๋ฆผ" in example_text: + continue + + # ์„ ํƒ์ง€๋ฅผ ๋ฆฌ์ŠคํŠธ๋กœ ์ €์žฅ + choices = [] + for choice in item.select(".questionCon li label"): + choices.append(choice.get_text(strip=True)) + + # ์ •๋‹ต ๋ฒˆํ˜ธ ์ถ”์ถœ + answer_element = item.select_one(".answer_num") + answer_text = answer_element.get_text(strip=True).strip() if answer_element else "" + answer = answer_symbol_to_int(answer_text) + + # ์ •๋‹ต ์„ค๋ช… + explanation = item.select_one(".answer_explan") + explanation_text = explanation.get_text(strip=True) if explanation else "" + + # ๋ฐ์ดํ„ฐ ์ €์žฅ + question_data = { + "question": question_text, + "paragraph": example_text, + "choices": json.dumps(choices, ensure_ascii=False), + "answer": answer, + "answer_explanation": explanation_text, + } + + # ํ…Œ์ด๋ธ” ์œ ๋ฌด์— ๋”ฐ๋ผ ๋‹ค๋ฅธ ๋ฆฌ์ŠคํŠธ์— ์ €์žฅ + if has_table: + questions_with_table.append(question_data) + else: + questions_without_table.append(question_data) + + return pd.DataFrame(questions_without_table) if not with_table else pd.DataFrame(questions_with_table) + + +def crawl_and_save(subject_code): + headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"} + + url = "https://gichulpass.com/bbs/board.php" + + # ํšŒ๊ณ„ํ•™ ๋ฌธ์ œ + if subject_code == 20: + wr_ids = [903, 27, 889, 887, 885, 884, 880] + + # ํ—Œ๋ฒ• ๋ฌธ์ œ + if subject_code == 26: + wr_ids = ( + [1136, 1063, 963, 953, 875, 849, 543, 537, 526, 515, 510, 500, 542, 536, 525, 514, 509] + + [499, 541, 535, 524, 513, 508, 498, 540, 534, 523, 512, 507, 497, 539, 533, 522, 511, 506] + + [496, 538, 532, 521, 505, 495, 531, 520, 504, 494, 530, 519, 503, 493, 529, 518, 502, 492] + + [528, 517, 501, 491, 527, 516, 490] + ) + + # ํ•œ๊ตญ์‚ฌ ๋ฌธ์ œ + if subject_code == 34: + # 9๊ธ‰๋งŒ ํ•„ํ„ฐ๋ง + wr_ids = ( + list(range(808, 814)) + + list(range(224, 243)) + + list(range(264, 269)) + + list(range(288, 298)) + + list(range(305, 326)) + + [16, 841, 870, 897, 897, 912, 962, 1012, 1026, 1053, 1167, 1172, 1173, 1176, 1177, 1184] + + [1205, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1257, 1262, 1357, 1367, 1368, 1404, 1405] + ) + + # ์‚ฌํšŒ ๋ฌธ์ œ + if subject_code == 35: + wr_ids = list(range(808, 840)) + [865, 894, 908, 1010, 1027, 1050] + + dfs = [] + for wr_id in wr_ids: + params = {"bo_table": "exam", "wr_id": wr_id, "subject": subject_code} + response = requests.get(url, params=params, headers=headers) + soup = BeautifulSoup(response.text, "html.parser") + df = extract_question_data(soup) + dfs.append(df) + + concated_df = pd.concat(dfs, axis=0) + len_df = len(concated_df) + concated_df.to_csv(f"gichulpass_{subject_code}_{len_df}_raw.csv", index=False, encoding="utf-8") + + +def check_KMMLU(input_file, output_file): + df = pd.read_csv(input_file, encoding="utf-8") + ds = load_dataset("HAERAE-HUB/KMMLU", "Korean-History") + # df์˜ None๊ฐ’์„ ๋นˆ ๋ฌธ์ž์—ด๋กœ 
๋ณ€ํ™˜ํ•˜๊ณ  ๋ชจ๋“  ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜ + df = df.fillna("") + df = df.astype(str) + + # ๋ชจ๋“  split์˜ question์„ ํ•˜๋‚˜๋กœ ํ•ฉ์น˜๊ธฐ + questions = pd.concat([pd.DataFrame(ds["train"]), pd.DataFrame(ds["dev"]), pd.DataFrame(ds["test"])]) + + # ํฌํ•จ ์—ฌ๋ถ€๋ฅผ ํ™•์ธํ•˜๋Š” ํ•จ์ˆ˜ + def check_inclusion(column): + return column.apply(lambda x: any(x in question for question in questions["question"])) + + # paragraph์™€ question ๊ฐ๊ฐ ํ™•์ธ ํ›„ ๋‘˜ ์ค‘ ํ•˜๋‚˜๋ผ๋„ True๋ฉด True + df["include"] = check_inclusion(df["paragraph"]) | check_inclusion(df["question"]) + + new_df = df[not df["include"]] + new_df.to_csv(output_file, index=False) + + +if __name__ == "__main__": + crawl_and_save(subject_code=20) + crawl_and_save(subject_code=26) + crawl_and_save(subject_code=34) + crawl_and_save(subject_code=35) diff --git a/data_process/external_musr.py b/data_process/external_musr.py new file mode 100644 index 0000000..8167799 --- /dev/null +++ b/data_process/external_musr.py @@ -0,0 +1,63 @@ +from data_process.process_google_translate import TranslationCache, translate_column, translate_list_column +from datasets import load_dataset +from loguru import logger +import pandas as pd +from tqdm import tqdm + + +def dataset_to_pd(data_name): + data = load_dataset(data_name) + dfs = [ + pd.DataFrame(data["murder_mysteries"]), + pd.DataFrame(data["object_placements"]), + pd.DataFrame(data["team_allocation"]), + ] + return pd.concat(dfs, axis=0) + + +def process_data(df): + return pd.DataFrame( + { + "paragraph": df["narrative"], + "question": df["question"], + "choices": df["choices"], + "answer": df["answer_index"].apply(lambda x: x + 1), + } + ) + + +def process_external_datasets(dataset_name, output_filename): + df = dataset_to_pd(dataset_name) + df = process_data(df) + + # ๊ฐœํ–‰๋ฌธ์ž๋ฅผ \n ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜ + for col in df.columns: + if df[col].dtype == "object": # ๋ฌธ์ž์—ด ์ปฌ๋Ÿผ์— ๋Œ€ํ•ด์„œ๋งŒ ์ฒ˜๋ฆฌ + df[col] = df[col].str.replace("\n", "\\n") + + df.to_csv(output_filename, index=False) + + +def translate_df(input_filename, output_filename): + df = pd.read_csv(input_filename) + + logger.info("๋‹จ๋ฝ ๋ฒˆ์—ญ ์ค‘...") + with TranslationCache("paragraph_cache.json") as paragraph_cache: + df["paragraph"] = translate_column(df["paragraph"], paragraph_cache) + + logger.info("์งˆ๋ฌธ ๋ฒˆ์—ญ ์ค‘...") + with TranslationCache("question_cache.json") as question_cache: + df["question"] = translate_column(df["question"], question_cache) + + logger.info("์„ ํƒ์ง€ ๋ฒˆ์—ญ ์ค‘...") + with TranslationCache("choices_cache.json") as choices_cache: + tqdm.pandas(desc="์„ ํƒ์ง€ ๋ฒˆ์—ญ ์ค‘") + df["choices"] = df["choices"].apply(lambda x: translate_list_column(x, choices_cache)) + + df.to_csv(output_filename, index=False) + + +if __name__ == "__main__": + dataset_name = "TAUR-Lab/MuSR" + process_external_datasets(dataset_name, "MuSR_en_raw.csv") + translate_df("MuSR_en_raw.csv", "MuSR_ko_raw.csv") diff --git a/data_process/external_race.py b/data_process/external_race.py new file mode 100644 index 0000000..cd7f01c --- /dev/null +++ b/data_process/external_race.py @@ -0,0 +1,73 @@ +from datasets import load_dataset +import pandas as pd + + +def dataset_to_pd(data_name): + """์ฃผ์–ด์ง„ ๋ฐ์ดํ„ฐ์…‹ ์ด๋ฆ„์œผ๋กœ๋ถ€ํ„ฐ DataFrame์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.""" + dataset = load_dataset(data_name, "high", split="validation") # 'train', 'validation', 'test' ์ค‘ ์„ ํƒ + return pd.DataFrame(dataset) + + +def process_query_data(df): + """DataFrame์„ ์ฒ˜๋ฆฌํ•˜์—ฌ ํ•„์š”ํ•œ ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.""" + + # answer๋ฅผ 
A, B, C, D ํ˜•์‹์—์„œ 1, 2, 3, 4 ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ + def convert_answer(answer): + if answer == "A": + return 1 + elif answer == "B": + return 2 + elif answer == "C": + return 3 + elif answer == "D": + return 4 + else: + return None # ์˜ˆ์ƒ์น˜ ๋ชปํ•œ ๊ฐ’์— ๋Œ€ํ•œ ์ฒ˜๋ฆฌ + + df["article"] = ( + df["article"] + .str.replace(",", "") + .str.replace('""', "") + .str.replace(r"\. ", ".") + .str.replace(r'\." ', '."') + .str.replace(r"([.!?]) ", r"\1") + ) + # ๋ฌธ์ œ๋ฅผ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ DataFrame ์ƒ์„ฑ + problems = df.apply( + lambda row: {"question": row["question"], "choices": row["options"], "answer": convert_answer(row["answer"])}, + axis=1, + ) + + return pd.DataFrame( + { + "id": df["example_id"], # ์˜ˆ์ œ์˜ ID + "paragraph": df["article"], # ์ •๋ฆฌ๋œ article ์‚ฌ์šฉ + "problems": problems.apply(str), # ๋ฌธ์ œ๋ฅผ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜ + "question_plus": None, # question_plus๊ฐ€ ์›๋ณธ ๋ฐ์ดํ„ฐ์— ์—†๋‹ค๊ณ  ๊ฐ€์ • + } + ) + + +if __name__ == "__main__": + dataset_name = "ehovy/race" # ์‚ฌ์šฉํ•  ๋ฐ์ดํ„ฐ์…‹ ์ด๋ฆ„ + df = dataset_to_pd(dataset_name) # ๋ฐ์ดํ„ฐ์…‹์„ DataFrame์œผ๋กœ ๋ณ€ํ™˜ + processed_df = process_query_data(df) # ๋ฐ์ดํ„ฐ ์ฒ˜๋ฆฌ + + # ๊ฒฐ๊ณผ๋ฅผ CSV ํŒŒ์ผ๋กœ ์ €์žฅ (๋”ฐ์˜ดํ‘œ ์ฒ˜๋ฆฌ) + processed_df.to_csv("processed_race_dataset.csv", index=False) + + import re + + import pandas as pd + + # CSV ํŒŒ์ผ ์ฝ๊ธฐ + df = pd.read_csv("processed_race_dataset.csv") + + # paragraph ์—ด๋งŒ ์ˆ˜์ • + df["paragraph"] = df["paragraph"].apply(lambda x: re.sub(r"\n", "", x)) + df["paragraph"] = df["paragraph"].apply(lambda x: re.sub(r"\. ", ".", x)) + df["paragraph"] = df["paragraph"].apply(lambda x: re.sub(r'\." ', '."', x)) + df["paragraph"] = df["paragraph"].apply(lambda x: re.sub(r"([.!?]) ", r"\1", x)) + + # ์ˆ˜์ •๋œ DataFrame์„ ๋‹ค์‹œ CSV๋กœ ์ €์žฅ + df.to_csv("modified_file.csv", index=False) diff --git a/data_process/external_sat_gaokao.py b/data_process/external_sat_gaokao.py new file mode 100644 index 0000000..904b285 --- /dev/null +++ b/data_process/external_sat_gaokao.py @@ -0,0 +1,86 @@ +from datasets import load_dataset +import pandas as pd + + +def dataset_to_pd(data_name): + data = load_dataset(data_name) + return pd.DataFrame(data["test"]) + + +def process_query_data(input_df): + def _split_query_data(df): + paragraphs = [] + questions = [] + + for index, row in df.iterrows(): + text = row["query"] + # Paragraph์™€ ๋‚˜๋จธ์ง€ ํ…์ŠคํŠธ ๋ถ„๋ฆฌ + paragraph, rest = text.split("Q:", 1) + # Question๊ณผ Answer Choices ๋ถ„๋ฆฌ + question, choices = rest.split("Answer Choices: ", 1) + + paragraphs.append(paragraph.strip()) + questions.append(question.strip()) + + return pd.DataFrame( + { + "paragraph": paragraphs, + "question": questions, + "choices": df["choices"], + "answer": df["gold"].apply(lambda x: x[0] + 1), # gold ๋ฐฐ์—ด์„ ํ’€๊ณ  +1 + } + ) + + def _split_answer_choices(df): + def process_choices(choices_list): # ๋ฆฌ์ŠคํŠธ ํ˜•ํƒœ๋กœ ์ž…๋ ฅ ๋ฐ›์Œ + new_choices = [] + for choice in choices_list: + new_choice = ( + choice.replace("(A)", "") + .replace("(B)", "") + .replace("(C)", "") + .replace("(D)", "") + .replace("(E)", "") + ) + new_choices.append(new_choice.strip()) + return new_choices + + df["choices"] = df["choices"].apply(process_choices) + return df + + split_df = _split_query_data(input_df) + final_df = _split_answer_choices(split_df) + return final_df + + +def process_and_concat_external_datasets(dataset_names, output_filename): + dfs = [] + for dataset_name in dataset_names: + df = dataset_to_pd(dataset_name) + df = process_query_data(df) + 
dfs.append(df) + + concated_df = pd.concat(dfs, axis=0) + concated_df.to_csv(output_filename, index=False) + + +def clean_string(text): + # ๋ฌธ์ž์—ด ๋‚ด๋ถ€์˜ ๋ชจ๋“  ํฐ๋”ฐ์˜ดํ‘œ๋ฅผ ์ž‘์€๋”ฐ์˜ดํ‘œ๋กœ ๋ณ€ํ™˜ + text = text.replace('"', "'") + # ์—ฐ์†๋œ ํฐ๋”ฐ์˜ดํ‘œ๋ฅผ ํ•˜๋‚˜๋กœ ๋ณ€ํ™˜ + while "''" in text: + text = text.replace("''", "'") + text = text.strip() # ์•ž๋’ค ๊ณต๋ฐฑ ์ œ๊ฑฐ + return text + + +if __name__ == "__main__": + dataset_names = [ + "dmayhem93/agieval-sat-en", + "dmayhem93/agieval-logiqa-en", + "dmayhem93/agieval-lsat-rc", + "dmayhem93/agieval-lsat-lr", + "dmayhem93/agieval-lsat-ar", + "dmayhem93/agieval-gaokao-english", + ] + process_and_concat_external_datasets(dataset_names, "sat_gaokao_en_raw.csv") diff --git a/data_process/pdf_to_txt.py b/data_process/pdf_to_txt.py new file mode 100644 index 0000000..61889a7 --- /dev/null +++ b/data_process/pdf_to_txt.py @@ -0,0 +1,29 @@ +import os +import re + +from pdfminer.high_level import extract_text + + +def split_text_by_keyword(text, keyword): + sections = re.split(rf"{keyword}", text) + sections = [section.strip() + keyword for section in sections[:-1]] + [sections[-1].strip()] + return sections + + +def save_sections_to_files(sections, output_dir="sections"): + os.makedirs(output_dir, exist_ok=True) + for i, section in enumerate(sections): + file_name = os.path.join(output_dir, f"section_{i+1}.txt") + with open(file_name, "w", encoding="utf-8") as f: + f.write(section) + print(f"Sections saved to '{output_dir}' directory.") + + +if __name__ == "__main__": + pdf_file_path = "./data/test/2025.pdf" + output_dir = "./data/test/sections" + keyword = "๋‹ตํ•˜์‹œ์˜ค" + + text = extract_text(pdf_file_path) + split_text = split_text_by_keyword(text, keyword) + save_sections_to_files(split_text, output_dir=output_dir) diff --git a/data_process/process_balance_choices.py b/data_process/process_balance_choices.py new file mode 100644 index 0000000..4f1fa39 --- /dev/null +++ b/data_process/process_balance_choices.py @@ -0,0 +1,60 @@ +from ast import literal_eval +import random + +import pandas as pd + + +def balance_choices_dataset(file_path): + # CSV ํŒŒ์ผ ์ฝ๊ธฐ + df = pd.read_csv(file_path) + df["problems"] = df["problems"].apply(literal_eval) + + # ์„ ํƒ์ง€์™€ ๋‹ต ๊ต์ฒด ํ•จ์ˆ˜ + def swap_choices_and_answer(df): + for index, row in df.iterrows(): + problems = row["problems"] + choices = problems["choices"] + answer = problems["answer"] + # ์„ ํƒ์ง€ ๋žœ๋ค ์„ž๊ธฐ + shuffled_choices = choices[:] + random.shuffle(shuffled_choices) + # ์ƒˆ๋กœ์šด ๋‹ต ์ธ๋ฑ์Šค ๊ณ„์‚ฐ + new_answer = shuffled_choices.index(choices[answer - 1]) + # ๊ต์ฒด๋œ ์„ ํƒ์ง€์™€ ๋‹ต์œผ๋กœ ์—…๋ฐ์ดํŠธ + problems["choices"] = shuffled_choices + problems["answer"] = new_answer + 1 + df.at[index, "problems"] = problems + return df + + # ์„ ํƒ์ง€์™€ ๋‹ต ๊ต์ฒด ์ ์šฉ + return swap_choices_and_answer(df) + + +def answer_counts(file_path): + df = pd.read_csv(file_path) + df["problems"] = df["problems"].apply(literal_eval) + + records = [] + for idx, row in df.iterrows(): + problems = row["problems"] + record = { + "id": row["id"], + "paragraph": row["paragraph"], + "question": problems["question"], + "choices": problems["choices"], + "answer": problems.get("answer", None), + "question_plus": problems.get("question_plus", None), + } + records.append(record) + + processed_df = pd.DataFrame(records) + print(len(processed_df)) + + print(processed_df["choices"].apply(len).value_counts()) + print(processed_df["answer"].value_counts()) + + +if __name__ == "__main__": + 
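# Shuffle every question's choices, re-map the 1-based answer to its new position, then print the answer distribution to confirm the labels are balanced.
+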
train_balanced = balance_choices_dataset("../data/train.csv") + train_balanced.to_csv("../data/train_balanced.csv", index=False) + answer_counts("../data/train_balanced.csv") diff --git a/data_process/process_formatting.py b/data_process/process_formatting.py new file mode 100644 index 0000000..4dd9f87 --- /dev/null +++ b/data_process/process_formatting.py @@ -0,0 +1,29 @@ +import pandas as pd + + +def formatting(suffix, input_filename, output_filename): + # CSV ํŒŒ์ผ ์ฝ๊ธฐ + df = pd.read_csv(input_filename, encoding="utf-8") + + # ์ƒˆ๋กœ์šด ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜ + new_records = [] + for idx, row in df.iterrows(): + new_record = { + "id": f"external-data-{suffix}{idx + 1}", + "paragraph": "" if pd.isna(row["paragraph"]) else str(row["paragraph"]), + "problems": {"question": row["question"].strip(), "choices": eval(row["choices"]), "answer": row["answer"]}, + } + new_records.append(new_record) + + # ์ƒˆ๋กœ์šด DataFrame ์ƒ์„ฑ + new_df = pd.DataFrame(new_records) + new_df.to_csv(output_filename, index=False) + + +if __name__ == "__main__": + formatting("gichulpass20-", "gichulpass_20_107_raw.csv", "gichulpass_20_107.csv") + formatting("gichulpass26-", "gichulpass_26_1319_raw.csv", "gichulpass_24_1319.csv") + formatting("gichulpass34-", "gichulpass_34_1352_raw.csv", "gichulpass_34_1352.csv") + formatting("gichulpass35-", "gichulpass_35_568_raw.csv", "gichulpass_35_568.csv") + formatting("SAT", "sat_gaokao_ko_raw.csv", "sat_gaokao_ko.csv") + formatting("MuSR", "MuSR_ko_raw.csv", "MuSR_ko.csv") diff --git a/data_process/process_google_translate.py b/data_process/process_google_translate.py new file mode 100644 index 0000000..cf62868 --- /dev/null +++ b/data_process/process_google_translate.py @@ -0,0 +1,80 @@ +from ast import literal_eval +import json +import os +import time + +from googletrans import Translator +from loguru import logger +from tqdm import tqdm + + +class TranslationCache: + def __init__(self, cache_file="translation_cache.json"): + self.cache_file = cache_file + self.cache = {} + + def __enter__(self): + if os.path.exists(self.cache_file): + with open(self.cache_file, "r", encoding="utf-8") as f: + self.cache = json.load(f) + return self + + def __exit__(self, exc_type, exc_value, traceback): + with open(self.cache_file, "w", encoding="utf-8") as f: + json.dump(self.cache, f, ensure_ascii=False, indent=2) + self.cache.clear() # ๋ฉ”๋ชจ๋ฆฌ ํ•ด์ œ + + def get_translation(self, text): + return self.cache.get(text) + + def add_translation(self, text, translation): + self.cache[text] = translation + + +def translate_with_retry(translator, text, cache, max_retries=3): + # ์บ์‹œ์—์„œ ๋ฒˆ์—ญ ํ™•์ธ + cached_translation = cache.get_translation(text) + if cached_translation: + return cached_translation + + # ์บ์‹œ์— ์—†๋Š” ๊ฒฝ์šฐ ๋ฒˆ์—ญ ์ˆ˜ํ–‰ + for attempt in range(max_retries): + try: + translated = translator.translate(text, src="en", dest="ko").text + time.sleep(0.01) + cache.add_translation(text, translated) + return translated + except Exception as e: + if attempt == max_retries - 1: + logger.info(f"๋ฒˆ์—ญ ์‹คํŒจ: {str(e)}") + return text + time.sleep(1) + + +def translate_list_column(text, cache): + items = literal_eval(text) + translator = Translator() + translated_items = [] + + for item in items: + translated = translate_with_retry(translator, item, cache) + translated_items.append(translated) + return translated_items + + +def translate_column(texts, cache): + translator = Translator() + translated_texts = [] + + # ์ค‘๋ณต ์ œ๊ฑฐ๋ฅผ ์œ„ํ•ด ์œ ๋‹ˆํฌํ•œ ํ…์ŠคํŠธ๋งŒ ์ถ”์ถœ + 
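# First pass below translates each unique string once (translate_with_retry stores the result in the cache);
+ # the second pass re-reads the cache in the original row order, so duplicate texts cost only one API call.
+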
unique_texts = list(set(texts)) + + for text in tqdm(unique_texts, desc="๊ณ ์œ  ํ…์ŠคํŠธ ๋ฒˆ์—ญ ์ค‘"): + translated = translate_with_retry(translator, text, cache) + + # ์›๋ณธ ์ˆœ์„œ๋Œ€๋กœ ์บ์‹œ์—์„œ ๋ฒˆ์—ญ ๊ฐ€์ ธ์˜ค๊ธฐ + for text in texts: + translated = cache.get_translation(text) + translated_texts.append(translated) + + return translated_texts diff --git a/data_viz/csv2pdf.py b/data_viz/csv2pdf.py new file mode 100644 index 0000000..c08e0fb --- /dev/null +++ b/data_viz/csv2pdf.py @@ -0,0 +1,132 @@ +import argparse +import ast +import os +import sys +import textwrap + +import pandas as pd +from reportlab.lib.pagesizes import A4 +from reportlab.lib.units import cm +from reportlab.pdfbase import pdfmetrics +from reportlab.pdfbase.ttfonts import TTFont +from reportlab.pdfgen import canvas + + +def draw_wrapped_text(c, text, x, y, max_width, max_height, line_height=14): + """ํ…์ŠคํŠธ๋ฅผ ํŽ˜์ด์ง€ ๋„ˆ๋น„์— ๋งž๊ฒŒ ์ค„๋ฐ”๊ฟˆํ•˜์—ฌ ์ถœ๋ ฅํ•˜๋ฉฐ, ํŽ˜์ด์ง€ ๋†’์ด๋ฅผ ์ดˆ๊ณผํ•˜๋ฉด ํŽ˜์ด์ง€๋ฅผ ๋„˜๊น€.""" + wrapped_text = textwrap.fill(text, width=70) + text_obj = c.beginText(x, y) + text_obj.setFont("NanumGothic", 10) + text_obj.setLeading(line_height) + + for line in wrapped_text.splitlines(): + if text_obj.getY() < max_height: + c.drawText(text_obj) + c.showPage() + text_obj = c.beginText(x, A4[1] - 3 * cm) + text_obj.setFont("NanumGothic", 10) + text_obj.setLeading(line_height) + text_obj.textLine(line) + + c.drawText(text_obj) + + +def create_csat_style_pdf(data, filename): + c = canvas.Canvas(filename, pagesize=A4) + width, height = A4 + + for index, row in data.iterrows(): + try: + # ๋ฌธ์ œ ID ์ถœ๋ ฅ + c.setFont("NanumGothic", 12) + c.drawString(1 * cm, height - 1 * cm, f"๋ฌธ์ œ ID: {row['id']}") + + # ๋ณธ๋ฌธ ์ถœ๋ ฅ + paragraph = row["paragraph"] + draw_wrapped_text(c, paragraph, 1 * cm, height - 2 * cm, max_width=85, max_height=14 * cm) + + # ๋ฌธ์ œ ์ถœ๋ ฅ + problem_data = ast.literal_eval(row["problems"]) + question = problem_data["question"] + draw_wrapped_text( + c, + f"๋ฌธ์ œ: {question}", + 1 * cm, + height - 16 * cm, + max_width=85, + max_height=8 * cm, + ) + + # ์„ ํƒ์ง€ ์ถœ๋ ฅ + choices = problem_data["choices"] + choice_y = height - 19 * cm + for i, choice in enumerate(choices, 1): + draw_wrapped_text( + c, + f"{i}. {choice}", + 1 * cm, + choice_y, + max_width=85, + max_height=3 * cm, + ) + choice_y -= 1.2 * cm + + # ์ •๋‹ต ํ‘œ์‹œ + answer = problem_data["answer"] + draw_wrapped_text( + c, + f"์ •๋‹ต: {answer}", + 1 * cm, + choice_y - 1 * cm, + max_width=85, + max_height=3 * cm, + ) + + c.showPage() + except KeyError as ke: + print(f"Data format error: ํ•„์ˆ˜ ํ‚ค๊ฐ€ ์—†์Šต๋‹ˆ๋‹ค. 
{ke}") + + # PDF ์ €์žฅ + c.save() + + +if __name__ == "__main__": + # ์ธ์ž ํŒŒ์„œ ์„ค์ • + parser = argparse.ArgumentParser(description="Generate a CSAT style PDF from CSV data.") + parser.add_argument( + "--csv_path", + default="../data/train.csv", + help="Path to the CSV file containing the data.", + ) + args = parser.parse_args() + + # CSV ํŒŒ์ผ ์ฝ๊ธฐ ๋ฐ ์ปฌ๋Ÿผ ํ™•์ธ + try: + df = pd.read_csv(args.csv_path) + required_columns = {"id", "paragraph", "problems"} + if not required_columns.issubset(df.columns): + missing_columns = required_columns - set(df.columns) + raise ValueError(f"CSV ํŒŒ์ผ์— ํ•„์š”ํ•œ ์ปฌ๋Ÿผ์ด ์—†์Šต๋‹ˆ๋‹ค: {', '.join(missing_columns)}") + except FileNotFoundError: + print(f"CSV ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {args.csv_path}") + sys.exit(1) + except ValueError as ve: + print(ve) + sys.exit(1) + + # ํ•œ๊ธ€ ํฐํŠธ ๋“ฑ๋ก + font_path = os.path.abspath("../data/NanumGothic.ttf") + if not os.path.isfile(font_path): + print( + f"ํฐํŠธ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {font_path}" + f"https://hangeul.naver.com/fonts/search?f=nanum ์—์„œ ํฐํŠธ๋ฅผ ๋‹ค์šด๋ฐ›์•„ dataํด๋”์— ๋„ฃ์–ด์ฃผ์„ธ์š”." + ) + sys.exit(1) + + pdfmetrics.registerFont(TTFont("NanumGothic", font_path)) + + # PDF ํŒŒ์ผ๋ช… ์„ค์ • (์ž…๋ ฅ ํŒŒ์ผ๊ณผ ๋™์ผ ๊ฒฝ๋กœ ๋ฐ ์ด๋ฆ„์œผ๋กœ ์„ค์ •, ํ™•์žฅ์ž๋งŒ .pdf๋กœ ๋ณ€๊ฒฝ) + pdf_filename = os.path.splitext(args.csv_path)[0] + ".pdf" + + # PDF ์ƒ์„ฑ ํ•จ์ˆ˜ ํ˜ธ์ถœ + create_csat_style_pdf(df, pdf_filename) diff --git a/data_viz/labeling.py b/data_viz/labeling.py new file mode 100644 index 0000000..9c60556 --- /dev/null +++ b/data_viz/labeling.py @@ -0,0 +1,98 @@ +from ast import literal_eval + +import pandas as pd +import streamlit as st + + +def load_data(file_path): + data = pd.read_csv(file_path) + records = [] + for _, row in data.iterrows(): + problems = literal_eval(row["problems"]) + record = { + "id": row["id"], + "paragraph": row["paragraph"], + "question": problems["question"], + "choices": problems["choices"], + "answer": problems.get("answer", None), + "question_plus": problems.get("question_plus", None), + "target": problems.get("target", None), + "suggested_label": problems.get("suggested_label", None), + "is_label_issue": problems.get("is_label_issue", None), + } + records.append(record) + return data, records + + +def display_instance(record): + st.subheader("Paragraph") + st.write(record["paragraph"]) + + st.subheader("Question, Choices, Answer") + st.markdown("#### Question:") + st.write(record["question"]) + + st.markdown("#### Choices:") + for i, choice in enumerate(record["choices"], 1): + st.write(f"{i} : {choice}") + + st.markdown("#### Answer:") + st.write(str(record["answer"])) + + st.subheader("Question Plus") + st.write(str(record["question_plus"])) + + +def main(): + st.title("CSV ๋ฐ์ดํ„ฐ ์ธ์Šคํ„ด์Šค ๋ทฐ์–ด") + + data, records = load_data("../data/cleaned_output_with_labels_CL.csv") + + # 1055๋ฒˆ์งธ ์ธ๋ฑ์Šค๋ฅผ ๊ธฐ์ค€์œผ๋กœ ๋ฐ์ดํ„ฐ ๋ถ„ํ•  + split_index = 792 + before_split = data.iloc[:split_index] + after_split = data.iloc[split_index:] + + # 1055 ์ด์ „๊ณผ ์ดํ›„์˜ suggested_label ๊ฐœ์ˆ˜ ๊ณ„์‚ฐ + label_counts = { + "Before 1380": { + "Label 0": (before_split["suggested_label"] == 0).sum(), + "Not Label 0": (before_split["suggested_label"] != 0).sum(), + }, + "After 1380": { + "Label 1": (after_split["suggested_label"] == 1).sum(), + "Not Label 1": (after_split["suggested_label"] != 1).sum(), + }, + } + + # ๊ฒฐ๊ณผ๋ฅผ ๋ฐ์ดํ„ฐํ”„๋ ˆ์ž„์œผ๋กœ ๋ณ€ํ™˜ + label_counts_df = pd.DataFrame(label_counts) + + # ๋ผ๋ฒจ ๊ฐœ์ˆ˜ 
์ถœ๋ ฅ + st.subheader("Suggested Label ๊ฐœ์ˆ˜") + st.write(label_counts_df) + + # Before Split์—์„œ suggested_label์ด 1์ธ ํ–‰ ํ•„ํ„ฐ๋ง + before_split_label_1 = before_split[before_split["suggested_label"] == 1] + + # After Split์—์„œ suggested_label์ด 1์ธ ํ–‰ ํ•„ํ„ฐ๋ง + after_split_label_1 = after_split[after_split["suggested_label"] == 0] + + # ๊ฒฐ๊ณผ ์ถœ๋ ฅ + st.subheader("Before 1380์—์„œ suggested_label์ด 1์ธ ์ธ์Šคํ„ด์Šค") + st.write(before_split_label_1) + + st.subheader("After 1380์—์„œ suggested_label์ด 0์ธ ์ธ์Šคํ„ด์Šค") + st.write(after_split_label_1) + + # ์ธ์Šคํ„ด์Šค ์„ ํƒ ๊ธฐ๋Šฅ ์ถ”๊ฐ€ + instance_index = st.number_input("์ธ์Šคํ„ด์Šค ์„ ํƒ", min_value=0, max_value=len(data) - 1, value=0, step=1) + + st.write(f"์„ ํƒ๋œ ์ธ์Šคํ„ด์Šค (์ธ๋ฑ์Šค {instance_index}):") + st.write(data.iloc[instance_index]) + + display_instance(records[instance_index]) + + +if __name__ == "__main__": + main() diff --git a/data_viz/streamlit_app.py b/data_viz/streamlit_app.py new file mode 100644 index 0000000..51153d9 --- /dev/null +++ b/data_viz/streamlit_app.py @@ -0,0 +1,71 @@ +from ast import literal_eval +import re + +import pandas as pd +import streamlit as st + + +def load_data(file_path): + data = pd.read_csv(file_path) + records = [] + for _, row in data.iterrows(): + problems = literal_eval(row["problems"]) + record = { + "id": row["id"], + "paragraph": row["paragraph"], + "question": problems["question"], + "choices": problems["choices"], + "answer": problems.get("answer", None), + "question_plus": problems.get("question_plus", None), + "documents": row.get("documents", None), + } + records.append(record) + return data, records + + +def display_instance(left, right, record): + with left: + st.subheader("Paragraph") + st.write(record["paragraph"]) + + st.subheader("Question, Choices, Answer") + st.markdown("#### Question:") + st.write(record["question"]) + + st.markdown("#### Choices:") + for i, choice in enumerate(record["choices"], 1): + if i == record["answer"]: + st.markdown(f"{i} : {choice}", unsafe_allow_html=True) + else: + st.write(f"{i} : {choice}") + + st.markdown("#### Answer:") + st.write(str(record["answer"])) + + st.subheader("Question Plus") + st.write(str(record["question_plus"])) + with right: + st.subheader("Documents") + result = re.split(r"(?=\[)", str(record["documents"])) + for part in result: + st.write(part) + + +def main(file_path="../data/train.csv"): + st.set_page_config(layout="wide") + st.title("CSV ๋ฐ์ดํ„ฐ ์ธ์Šคํ„ด์Šค ๋ทฐ์–ด") + data, records = load_data(file_path) + + left, right = st.columns([0.5, 0.5]) + with left: + instance_index = st.number_input("์ธ์Šคํ„ด์Šค ์„ ํƒ", min_value=0, max_value=len(data) - 1, value=0, step=1) + + display_instance(left, right, records[instance_index]) + + with left: + st.write(f"์„ ํƒ๋œ ์ธ์Šคํ„ด์Šค (์ธ๋ฑ์Šค {instance_index}):") + st.write(data.iloc[instance_index]) + + +if __name__ == "__main__": + main("../data/train_retrieve.csv") diff --git a/ensemble/hard_voting.py b/ensemble/hard_voting.py new file mode 100644 index 0000000..8965761 --- /dev/null +++ b/ensemble/hard_voting.py @@ -0,0 +1,90 @@ +from collections import Counter +import csv +from glob import glob + + +""" +# ํ•˜๋“œ ๋ณดํŒ… ์•™์ƒ๋ธ” ์‚ฌ์šฉ ๋ฐฉ๋ฒ• (CSV ๋ฒ„์ „) + +1. ํŒŒ์ผ ์ค€๋น„: + - 'ensemble/results_hard' ํด๋” ์•ˆ์— ์•™์ƒ๋ธ”ํ•˜๊ณ  ์‹ถ์€ ๋ชจ๋“  CSV ํŒŒ์ผ๋“ค์„ ๋„ฃ์Šต๋‹ˆ๋‹ค. + +2. ์šฐ์„ ์ˆœ์œ„ ์„ค์ •: + - 'priority_order' ๋ฆฌ์ŠคํŠธ์— ๋ชจ๋ธ์˜ ์šฐ์„ ์ˆœ์œ„๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค. 
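+ - Caveat: as written, the prediction dicts loaded below never receive a "filename" key, so on a tie the priority lookup finds no match and the first most-common answer is kept.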
+ - ์˜ˆ: priority_order = ['predictions1.csv', 'predictions2.csv', 'predictions3.csv'] + - ๋ฆฌ์ŠคํŠธ์˜ ์•ž์ชฝ์— ์žˆ๋Š” ๋ชจ๋ธ์ผ์ˆ˜๋ก ๋†’์€ ์šฐ์„ ์ˆœ์œ„๋ฅผ ๊ฐ€์ง‘๋‹ˆ๋‹ค. + +3. ์ฝ”๋“œ ์‹คํ–‰: + - ์„ค์ •์„ ๋งˆ์นœ ํ›„ ์ฝ”๋“œ๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค. + - ์ฝ”๋“œ๋Š” ์ž๋™์œผ๋กœ ํด๋” ๋‚ด์˜ ๋ชจ๋“  CSV ํŒŒ์ผ์„ ์ฝ์–ด ์•™์ƒ๋ธ”์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. + +4. ํ•˜๋“œ ๋ณดํŒ… ๊ณผ์ •: + - ๊ฐ ์งˆ๋ฌธ์— ๋Œ€ํ•ด ๋ชจ๋“  ๋ชจ๋ธ์˜ ๋‹ต๋ณ€์„ ์ˆ˜์ง‘ํ•ฉ๋‹ˆ๋‹ค. + - ๊ฐ€์žฅ ๋งŽ์ด ๋‚˜์˜จ ๋‹ต๋ณ€(๋“ค)์„ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค. + - ๋™์ ์ธ ๊ฒฝ์šฐ, ์šฐ์„ ์ˆœ์œ„๊ฐ€ ๊ฐ€์žฅ ๋†’์€ ๋ชจ๋ธ์˜ ๋‹ต๋ณ€์„ ์„ ํƒํ•ฉ๋‹ˆ๋‹ค. + +5. ๊ฒฐ๊ณผ ํ™•์ธ: + - ์•™์ƒ๋ธ” ๊ฒฐ๊ณผ๋Š” 'final_hard_predictions.csv' ํŒŒ์ผ๋กœ ์ €์žฅ๋ฉ๋‹ˆ๋‹ค. + - ์ด ํŒŒ์ผ์—๋Š” ๊ฐ ์งˆ๋ฌธ์— ๋Œ€ํ•œ ์ตœ์ข… ๋‹ต๋ณ€์ด ํฌํ•จ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค. + +์ฃผ์˜: ๋ชจ๋ธ์˜ ์šฐ์„ ์ˆœ์œ„๋Š” ๊ฐ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด๋‚˜ ํŠน์„ฑ์„ ๊ณ ๋ คํ•˜์—ฌ ์‹ ์ค‘ํžˆ ๊ฒฐ์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. +์šฐ์„ ์ˆœ์œ„ ์„ค์ •์— ๋”ฐ๋ผ ์ตœ์ข… ๊ฒฐ๊ณผ๊ฐ€ ํฌ๊ฒŒ ๋‹ฌ๋ผ์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. +""" + +# ์šฐ์„ ์ˆœ์œ„ ์ง์ ‘ ์ •์˜ +priority_order = ["output (1).csv", "output (7).csv", "output (8).csv"] + + +def hard_voting_with_priority(predictions, priority_order): + result = {} + for id in predictions[0].keys(): + answers = [pred[id] for pred in predictions if id in pred] + answer_counts = Counter(answer for answer in answers if answer) + + if answer_counts: + max_count = max(answer_counts.values()) + top_answers = [ans for ans, count in answer_counts.items() if count == max_count] + + if len(top_answers) == 1: + result[id] = top_answers[0] + else: + for model in priority_order: + model_index = next((i for i, pred in enumerate(predictions) if pred.get("filename") == model), None) + if model_index is not None: + model_answer = predictions[model_index].get(id, "") + if model_answer in top_answers: + result[id] = model_answer + break + else: + result[id] = top_answers[0] + else: + result[id] = "" + return result + + +# ์˜ˆ์ธก ํŒŒ์ผ๋“ค์„ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค. +prediction_files = glob("./results_hard/*.csv") +predictions = [] + +# ๊ฐ prediction ํŒŒ์ผ์„ ์ฝ์–ด์™€์„œ predictions ๋ฆฌ์ŠคํŠธ์— ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. +for file_name in prediction_files: + prediction = {} + with open(file_name, "r", encoding="utf-8") as file: + csv_reader = csv.reader(file) + next(csv_reader) # ํ—ค๋” ํ–‰์„ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค + for row in csv_reader: + prediction[row[0]] = row[1] # ์ฒซ ๋ฒˆ์งธ ์—ด์„ ํ‚ค๋กœ, ๋‘ ๋ฒˆ์งธ ์—ด์„ ๊ฐ’์œผ๋กœ ์‚ฌ์šฉ + # Remove filename key from prediction dictionary + predictions.append(prediction) + +# ํ•˜๋“œ ๋ณดํŒ…์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค. +final_predictions = hard_voting_with_priority(predictions, priority_order) + +# ๊ฒฐ๊ณผ๋ฅผ CSV ํŒŒ์ผ๋กœ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. +with open("final_hard_predictions.csv", "w", newline="", encoding="utf-8") as f: + writer = csv.writer(f) + writer.writerow(["id", "answer"]) # ํ—ค๋” ์ž‘์„ฑ + for id, answer in final_predictions.items(): + writer.writerow([id, answer]) + +print("์•™์ƒ๋ธ” ๊ฒฐ๊ณผ๊ฐ€ 'final_hard_predictions.csv' ํŒŒ์ผ๋กœ ์ €์žฅ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.") diff --git a/.github/.keep b/ensemble/results_hard/.gitkeep similarity index 100% rename from .github/.keep rename to ensemble/results_hard/.gitkeep diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..db9b2f8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[tool.ruff] +line-length = 120 + +# Exclude the following files and directories. 
+exclude = [ + ".git", + ".hg", + ".mypy_cache", + ".tox", + ".venv", + "_build", + "buck-out", + "build", + "dist", + "env", + "venv", + "**/*.ipynb", # Jupyter Notebook ํŒŒ์ผ ์ œ์™ธ +] + +[tool.ruff.lint] +# Never enforce `E501` (line length violations). +extend-select = ["C901", "E501", "E402"] +select = ["C", "E", "F", "I", "W"] + + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["E402", "F401", "F403", "F811"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 + +# Setting the order of sections +section-order = ["standard-library", "third-party", "local-folder"] +combine-as-imports = true +force-sort-within-sections = true + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..49cc160 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,51 @@ +# CUDA Version: 12.2 +# Ubuntu 20.04.6 +# python 3.10.13 + +# Deep Learning +auto_gptq==0.7.1 +bitsandbytes==0.44.1 +evaluate==0.4.3 +huggingface-hub==0.26.2 +numpy==2.0.0 +optimum==1.23.3 +peft==0.5.0 +scikit-learn==1.5.2 +torch==2.5.1 # 2.5.1+cu124 +tqdm==4.67.0 +transformers==4.46.2 +trl==0.12.0 +wandb==0.18.5 + +# RAG +elasticsearch==8.16.0 +konlpy==0.6.0 +rank-bm25==0.2.2 +wikiextractor==3.0.6 +faiss-cpu==1.9.0 # faiss-gpu==1.7.2 + +# Utils +beautifulsoup4==4.12.3 +ipykernel==6.29.5 +ipywidgets==8.1.5 +loguru==0.7.2 +matplotlib==3.9.2 +python-dotenv==1.0.1 +reportlab==4.2.5 +streamlit==1.40.1 +pdfminer.six==20240706 + +# Google Drive API +google-api-python-client==2.151.0 +google-auth-httplib2==0.2.0 +google-auth-oauthlib==1.2.1 + +# Automatically installed dependencies +# pandas==2.2.3 +# pyarrow==18.0.0 +# datasets==3.1.0 +# safetensors==0.4.5 +# scipy==1.14.1 +# tqdm==4.67.0 +# PyYAML==6.0.2 +# requests==2.32.3 diff --git a/script/run_script.bash b/script/run_script.bash new file mode 100755 index 0000000..c564b27 --- /dev/null +++ b/script/run_script.bash @@ -0,0 +1,11 @@ +#!/bin/bash + +# ์ฒซ ๋ฒˆ์งธ ์‹คํ—˜ +nohup python -u main.py --config config-normal.yaml & +PYTHON_PID_1=$! +wait $PYTHON_PID_1 + +# ๋‘ ๋ฒˆ์งธ ์‹คํ—˜ +nohup python -u main.py --config config-rag.yaml & +PYTHON_PID_2=$! +wait $PYTHON_PID_2 diff --git a/script/run_with_gpu_monitoring.bash b/script/run_with_gpu_monitoring.bash new file mode 100755 index 0000000..fecac05 --- /dev/null +++ b/script/run_with_gpu_monitoring.bash @@ -0,0 +1,31 @@ +#!/bin/bash + +# adamw_torch ์„ค์ • +nvidia-smi --query-gpu=timestamp,name,utilization.gpu,memory.used,memory.free --format=csv -l 1 > ../log/gpu_log_adamw_torch.csv & +NVIDIA_LOG_PID_1=$! +nohup python -u main.py --config config-adamw_torch.yaml & +PYTHON_PID_1=$! +wait $PYTHON_PID_1 + +# ์ฒซ ๋ฒˆ์งธ ๋ชจ๋‹ˆํ„ฐ๋ง ํ”„๋กœ์„ธ์Šค ์ข…๋ฃŒ +kill $NVIDIA_LOG_PID_1 + +# adafactor ์„ค์ • +nvidia-smi --query-gpu=timestamp,name,utilization.gpu,memory.used,memory.free --format=csv -l 1 > ../log/gpu_log_adafactor.csv & +NVIDIA_LOG_PID_2=$! +nohup python -u main.py --config config-adafactor.yaml & +PYTHON_PID_2=$! 
+wait $PYTHON_PID_2 + +# ๋‘ ๋ฒˆ์งธ ๋ชจ๋‹ˆํ„ฐ๋ง ํ”„๋กœ์„ธ์Šค ์ข…๋ฃŒ +kill $NVIDIA_LOG_PID_2 + +# adamw_bnb_8bit ์„ค์ • +nvidia-smi --query-gpu=timestamp,name,utilization.gpu,memory.used,memory.free --format=csv -l 1 > ../log/gpu_log_adamw_bnb_8bit.csv & +NVIDIA_LOG_PID_3=$! +nohup python -u main.py --config config-adamw_bnb_8bit.yaml & +PYTHON_PID_3=$! +wait $PYTHON_PID_3 + +# ์„ธ ๋ฒˆ์งธ ๋ชจ๋‹ˆํ„ฐ๋ง ํ”„๋กœ์„ธ์Šค ์ข…๋ฃŒ +kill $NVIDIA_LOG_PID_3 diff --git a/script/setup-gpu-server.bash b/script/setup-gpu-server.bash new file mode 100755 index 0000000..e73dcb0 --- /dev/null +++ b/script/setup-gpu-server.bash @@ -0,0 +1,111 @@ +#!/bin/bash + +########################################## +# GPU ์„œ๋ฒ„ ์ธ์Šคํ„ด์Šค ์ƒ์„ฑ ์‹œ ํ•„์š”ํ•œ ๊ฐœ๋ฐœ ํ™˜๊ฒฝ ์„ธํŒ… +# conda ๋ฏธ์„ค์น˜ ํ™˜๊ฒฝ์—์„œ๋Š” conda ์„ค์น˜ ๊ณผ์ •์„ ์ถ”๊ฐ€ +# ์œ ์ €๋ช… / ๋””๋ ‰ํ† ๋ฆฌ / ๊ถŒํ•œ ์„ค์ • ๋“ฑ ์ˆ˜์ •ํ•˜์—ฌ ์‚ฌ์šฉ +########################################## + +##################### Install ##################### +apt-get update +apt-get install -y sudo +sudo apt-get install -y wget git vim build-essential + +##################### Set root password ##################### +echo "root:root" | chpasswd + +##################### conda ##################### +export PATH="/opt/conda/bin:$PATH" +conda init bash +conda config --set auto_activate_base false +source ~/.bashrc +conda create -n main python=3.10.13 -y +sudo chmod -R 777 /opt/conda/env + +##################### Users: dir & permission ##################### +users=("camper") + +for i in "${!users[@]}"; do + user="${users[$i]}" + user_folder="/data/ephemeral/home/$user" + + # Create user with custom home directory and give sudo privileges + sudo mkdir -p $user_folder + sudo chmod 777 $user_folder + sudo adduser --disabled-password --home $user_folder --gecos "" $user + # Set user password same as username + echo "${user}:${user}" | sudo chpasswd + sudo chsh -s /bin/bash $user + echo "$user ALL=(ALL) NOPASSWD:ALL" | sudo tee /etc/sudoers.d/$user + +done + +##################### Users: conda ##################### +for user in "${users[@]}"; do + user_folder="/data/ephemeral/home/$user" + + # Add conda to each user's PATH and initialize conda + su - $user bash -c 'export PATH="/opt/conda/bin:$PATH"; conda init bash; conda config --set auto_activate_base false; source ~/.bashrc;' + echo "cd $user_folder" | sudo tee -a $user_folder/.bashrc + echo 'conda activate main' | sudo tee -a $user_folder/.bashrc + + # Add local bin path to each user's .bashrc + echo "export PATH=\$PATH:/data/ephemeral/home/$user/.local/bin" | sudo tee -a $user_folder/.bashrc + + sudo chmod -R 777 $user_folder + sudo chown -R $user:$user $user_folder + +done + +##################### Git ##################### +users=("sujin" "seongmin" "sungjae" "gayeon" "yeseo" "minseo") +BASE_DIR="/data/ephemeral/home/camper" + +# ๊ฐ ์‚ฌ์šฉ์ž๋ณ„ ๋””๋ ‰ํ† ๋ฆฌ ์ƒ์„ฑ +for user in "${users[@]}"; do + mkdir -p "$BASE_DIR/$user" +done + +# ๊ธ€๋กœ๋ฒŒ .gitconfig ์ƒ์„ฑ +cat << EOF > "$BASE_DIR/.gitconfig" +[user] + name = Camper User + email = camper@example.com + +# ์‚ฌ์šฉ์ž๋ณ„ ํด๋” ์„ค์ • ํฌํ•จ +EOF + +# includeIf ์„ค์ •์„ ๋™์ ์œผ๋กœ ์ถ”๊ฐ€ +for user in "${users[@]}"; do + cat << EOF >> "$BASE_DIR/.gitconfig" +[includeIf "gitdir:$BASE_DIR/$user/"] + path = $BASE_DIR/$user/.gitconfig +EOF +done + +# ๊ฐ ์‚ฌ์šฉ์ž ํด๋”์— .gitconfig ์ƒ์„ฑ +for user in "${users[@]}"; do + cat << EOF > "$BASE_DIR/$user/.gitconfig" +[user] + name = $user + email = $user@example.com +EOF +done + +# ๊ถŒํ•œ ์„ค์ • +chown -R camper:camper 
"$BASE_DIR" +chmod -R 755 "$BASE_DIR" + +echo "Git configuration setup completed!" + +echo "Setup complete!" + + + +##### git์€ ๊ฐ์ž ํด๋”์—์„œ ์„ธํŒ… ##### +# git clone https://"$token"@github.com/boostcampaitech7/level2-nlp-generationfornlp-nlp-02-lv3.git + +# git config --local user.email "$email" +# git config --local user.name "$username" +# git config --local credential.helper "cache --timeout=360000" +# git config --local commit.template .gitmessage.txt diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..06354ac --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[tool:pytest] +addopts = -ra -v -l diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_sample.py b/tests/test_sample.py new file mode 100644 index 0000000..d2b4018 --- /dev/null +++ b/tests/test_sample.py @@ -0,0 +1,2 @@ +def test_add(): + assert 1 + 2 == 3