diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a20f769 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +data/* +database/* +models/* +*.pyc diff --git a/NatSQL/natsql2sql/preprocess/TokenString.py b/NatSQL/natsql2sql/preprocess/TokenString.py index be86aeb..7cf25ce 100644 --- a/NatSQL/natsql2sql/preprocess/TokenString.py +++ b/NatSQL/natsql2sql/preprocess/TokenString.py @@ -245,21 +245,21 @@ def get_spacy_tokenizer(): return global_tokenizer import spacy - from spacy.symbols import ORTH, LEMMA + from spacy.symbols import ORTH, NORM nlp = spacy.load("en_core_web_sm") import re from spacy.tokenizer import Tokenizer - suffixes = nlp.Defaults.suffixes + (r'((\d{4}((_|-|/){1}\d{2}){2})|((\d{2})(_|-|/)){2}\d{4})(\s\d{2}(:\d{2}){2}){0,1}',) + (r'(\d{1,2}(st|nd|rd|th){0,1}(,|\s)){0,1}((J|j)an(uary){0,1}|(F|f)eb(ruary){0,1}|(M|m)ar(ch){0,1}|(A|a)pr(il){0,1}|(M|m)ay|(J|j)un(e){0,1}|(J|j)ul(y){0,1}|(A|a)ug(ust){0,1}|(S|s)ep(tember){0,1}|(O|o)ct(ober){0,1}|(N|n)ov(ember){0,1}|(D|d)ec(ember){0,1})(\s|,)(\d{1,2}(st|nd|rd|th){0,1}(\s|,){1,3}){0,1}\d{4}',) + ( r'(\d{1,6}(_|-|\+|/)\d{0,6}[A-Za-z]{0,6}\d{0,6}[A-Za-z]{0,6})',) + suffixes = nlp.Defaults.suffixes + list((r'((\d{4}((_|-|/){1}\d{2}){2})|((\d{2})(_|-|/)){2}\d{4})(\s\d{2}(:\d{2}){2}){0,1}',) + (r'(\d{1,2}(st|nd|rd|th){0,1}(,|\s)){0,1}((J|j)an(uary){0,1}|(F|f)eb(ruary){0,1}|(M|m)ar(ch){0,1}|(A|a)pr(il){0,1}|(M|m)ay|(J|j)un(e){0,1}|(J|j)ul(y){0,1}|(A|a)ug(ust){0,1}|(S|s)ep(tember){0,1}|(O|o)ct(ober){0,1}|(N|n)ov(ember){0,1}|(D|d)ec(ember){0,1})(\s|,)(\d{1,2}(st|nd|rd|th){0,1}(\s|,){1,3}){0,1}\d{4}',) + ( r'(\d{1,6}(_|-|\+|/)\d{0,6}[A-Za-z]{0,6}\d{0,6}[A-Za-z]{0,6})',)) suffix_regex = spacy.util.compile_suffix_regex(suffixes) nlp.tokenizer.suffix_search = suffix_regex.search - nlp.tokenizer.add_special_case(u'Ph.D', [{ORTH: u'Ph.D', LEMMA: u'ph.d'}]) - nlp.tokenizer.add_special_case(u'id', [{ORTH: u'id', LEMMA: u'id'}]) - nlp.tokenizer.add_special_case(u'Id', [{ORTH: u'Id', LEMMA: u'id'}]) - nlp.tokenizer.add_special_case(u'ID', [{ORTH: u'ID', LEMMA: u'id'}]) - nlp.tokenizer.add_special_case(u'iD', [{ORTH: u'iD', LEMMA: u'id'}]) - nlp.tokenizer.add_special_case(u'statuses', [{ORTH: u'statuses', LEMMA: u'status'}]) + nlp.tokenizer.add_special_case(u'Ph.D', [{ORTH: u'Ph.D', NORM: u'ph.d'}]) + nlp.tokenizer.add_special_case(u'id', [{ORTH: u'id', NORM: u'id'}]) + nlp.tokenizer.add_special_case(u'Id', [{ORTH: u'Id', NORM: u'id'}]) + nlp.tokenizer.add_special_case(u'ID', [{ORTH: u'ID', NORM: u'id'}]) + nlp.tokenizer.add_special_case(u'iD', [{ORTH: u'iD', NORM: u'id'}]) + nlp.tokenizer.add_special_case(u'statuses', [{ORTH: u'statuses', NORM: u'status'}]) global_tokenizer = Tokenizer_Similar_Allennlp(nlp) global_spacy = nlp diff --git a/README.md b/README.md index 26a68a5..7643ab5 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,11 @@
# RESDSQL: Decoupling Schema Linking and Skeleton Parsing for Text-to-SQL + This is the official implementation of the paper "RESDSQL: Decoupling Schema Linking and Skeleton Parsing for Text-to-SQL" (AAAI 2023). If this repository could help you, please cite the following paper: + ``` @inproceedings{li2022resdsql, author = {Haoyang Li and Jing Zhang and Cuiping Li and Hong Chen}, @@ -20,82 +22,92 @@ If this repository could help you, please cite the following paper: `Update (2023.3.13):` We evaluated our method on a diagnostic evaluation benchmark, [Dr.Spider](https://github.com/awslabs/diagnostic-robustness-text-to-sql), which contains 17 test sets to measure the robustness of Text-to-SQL parsers under different perturbation perspectives. `Update (2023.5.19):` We added support for [CSpider](https://taolusi.github.io/CSpider-explorer/), a Chinese Text-to-SQL benchmark with Chinese questions, English database schema, and corresponding SQL queries. - + ## Overview + We introduce a new Text-to-SQL parser, **RESDSQL** (**R**anking-enhanced **E**ncoding plus a **S**keleton-aware **D**ecoding framework for Text-to-**SQL**), which attempts to decoulpe the schema linking and the skeleton parsing to reduce the difficuty of Text-to-SQL. More details can be found in our [paper](https://arxiv.org/abs/2302.05965). All experiments are conducted on a single NVIDIA A100 80G GPU. ## Evaluation Results + We evaluate RESDSQL on six benchmarks: Spider, Spider-DK, Spider-Syn, Spider-Realistic, Dr.Spider, and CSpider. We adopt two metrics: Exact-set-Match accuracy (EM) and EXecution accuracy (EX). Let's look at the following numbers: **On Spider:** -| Model | Dev EM | Dev EX | Test EM | Test EX | -|-------|--------|--------|---------|---------| -| RESDSQL-3B+NatSQL | **80.5%** | **84.1%** | **72.0%** | **79.9%** | -| RESDSQL-3B | 78.0% | 81.8% | - | - | -| RESDSQL-Large+NatSQL | 76.7% | 81.9% | - | - | -| RESDSQL-Large | 75.8% | 80.1% | - | - | -| RESDSQL-Base+NatSQL | 74.1% | 80.2% | - | - | -| RESDSQL-Base | 71.7% | 77.9% | - | - | +| Model | Dev EM | Dev EX | Test EM | Test EX | +| -------------------- | --------------- | --------------- | --------------- | --------------- | +| RESDSQL-3B+NatSQL | **80.5%** | **84.1%** | **72.0%** | **79.9%** | +| RESDSQL-3B | 78.0% | 81.8% | - | - | +| RESDSQL-Large+NatSQL | 76.7% | 81.9% | - | - | +| RESDSQL-Large | 75.8% | 80.1% | - | - | +| RESDSQL-Base+NatSQL | 74.1% | 80.2% | - | - | +| RESDSQL-Base | 71.7% | 77.9% | - | - | **On Spider-DK, Spider-Syn, and Spider-Realistic:** -| Model | DK EM | DK EX | Syn EM | Syn EX | Realistic EM | Realistic EX | -|-------|-------|-------|--------|--------|--------------|--------------| -| RESDSQL-3B+NatSQL| 53.3% | 66.0% | 69.1% | 76.9% | 77.4% | 81.9% | +| Model | DK EM | DK EX | Syn EM | Syn EX | Realistic EM | Realistic EX | +| ----------------- | ----- | ----- | ------ | ------ | ------------ | ------------ | +| RESDSQL-3B+NatSQL | 53.3% | 66.0% | 69.1% | 76.9% | 77.4% | 81.9% | **On Dr.Spider's perturbation sets:** Following Dr.Spider, we only report **EX** for each post-perturbation set and choose PICARD and CodeX as our baseline methods. -| Perturbation set | PICARD | CodeX | RESDSQL-3B | RESDSQL-3B+NatSQL | -|------------------|--------|-------|-------------------|-----| -| DB-Schema-synonym | 56.5% | 62.0% | 63.3% | **68.3%** | -| DB-Schema-abbreviation | 64.7% | 68.6% | 64.5% | **70.0%** | -| DB-DBcontent-equivalence | 43.7% | **51.6%** | 40.3% | 40.1% | -| NLQ-Keyword-synonym | 66.3% | 55.5% | 67.5% | **72.4%** | -| NLQ-Keyword-carrier | 82.7% | 85.2% | **86.7%** | 83.5% | -| NLQ-Column-synonym | 57.2% | 54.7% | 57.4% | **63.1%** | -| NLQ-Column-carrier | 64.9% | 51.1% | **69.9%** | 63.9% | -| NLQ-Column-attribute | 56.3% | 46.2% | 58.8% | **71.4%** | -| NLQ-Column-value | 69.4% | 71.4% | 73.4% | **76.6%** | -| NLQ-Value-synonym | 53.0% | **59.9%** | 53.8% | 53.2% | -| NLQ-Multitype | 57.1% | 53.7% | 60.1% | **60.7%** | -| NLQ-Others | 78.3% | 69.7% | 77.3% | **79.0%** | -| SQL-Comparison | 68.0% | 66.9% | 70.2% | **82.0%** | -| SQL-Sort-order | 74.5% | 57.8% | 79.7% | **85.4%** | -| SQL-NonDB-number | 77.1% | **89.3%** | 83.2% | 85.5% | -| SQL-DB-text | 65.1% | 72.4% | 67.8% | **74.3%** | -| SQL-DB-number | 85.1% | 79.3% | 85.4% | **88.8%** | -| Average | 65.9% | 64.4% | 68.2% | **71.7%** | +| Perturbation set | PICARD | CodeX | RESDSQL-3B | RESDSQL-3B+NatSQL | +| ------------------------ | ------ | --------------- | --------------- | ----------------- | +| DB-Schema-synonym | 56.5% | 62.0% | 63.3% | **68.3%** | +| DB-Schema-abbreviation | 64.7% | 68.6% | 64.5% | **70.0%** | +| DB-DBcontent-equivalence | 43.7% | **51.6%** | 40.3% | 40.1% | +| NLQ-Keyword-synonym | 66.3% | 55.5% | 67.5% | **72.4%** | +| NLQ-Keyword-carrier | 82.7% | 85.2% | **86.7%** | 83.5% | +| NLQ-Column-synonym | 57.2% | 54.7% | 57.4% | **63.1%** | +| NLQ-Column-carrier | 64.9% | 51.1% | **69.9%** | 63.9% | +| NLQ-Column-attribute | 56.3% | 46.2% | 58.8% | **71.4%** | +| NLQ-Column-value | 69.4% | 71.4% | 73.4% | **76.6%** | +| NLQ-Value-synonym | 53.0% | **59.9%** | 53.8% | 53.2% | +| NLQ-Multitype | 57.1% | 53.7% | 60.1% | **60.7%** | +| NLQ-Others | 78.3% | 69.7% | 77.3% | **79.0%** | +| SQL-Comparison | 68.0% | 66.9% | 70.2% | **82.0%** | +| SQL-Sort-order | 74.5% | 57.8% | 79.7% | **85.4%** | +| SQL-NonDB-number | 77.1% | **89.3%** | 83.2% | 85.5% | +| SQL-DB-text | 65.1% | 72.4% | 67.8% | **74.3%** | +| SQL-DB-number | 85.1% | 79.3% | 85.4% | **88.8%** | +| Average | 65.9% | 64.4% | 68.2% | **71.7%** | Notice: We also employed the modified test suite script (see this [issue](https://github.com/awslabs/diagnostic-robustness-text-to-sql/issues/1)) to evaluate the model-generated results, but obtained the same numbers as above. Nevertheless, we suggest that further work should use their modified script to evaluate Dr.Spider. **On CSpider's development set:** -| Model | EM | EXEC | -| ----- | -- | ---- | -| RESDSQL-3B+NatSQL | **66.3%** | **81.1%** | -| RESDSQL-Large+NatSQL | 64.3% | 81.1% | -| LGESQL + GTL + Electra + QT | 64.0% | - | -| LGESQL + ELECTRA + QT | 64.5% | - | -| RESDSQL-Base+NatSQL | 61.7% | 78.1% | +| Model | EM | EXEC | +| --------------------------- | --------------- | --------------- | +| RESDSQL-3B+NatSQL | **66.3%** | **81.1%** | +| RESDSQL-Large+NatSQL | 64.3% | 81.1% | +| LGESQL + GTL + Electra + QT | 64.0% | - | +| LGESQL + ELECTRA + QT | 64.5% | - | +| RESDSQL-Base+NatSQL | 61.7% | 78.1% | ## Prerequisites + Create a virtual anaconda environment: + ```sh conda create -n your_env_name python=3.8.5 ``` + Active it and install the cuda version Pytorch: + ```sh -conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch +conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 cudatoolkit=11.3 -c pytorch ``` + Install other required modules and tools: + ```sh pip install -r requirements.txt -pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz +python -m spacy download en_core_web_sm python nltk_downloader.py ``` + Create several folders: + ```sh mkdir eval_results mkdir models @@ -103,7 +115,9 @@ mkdir tensorboard_log mkdir third_party mkdir predictions ``` + Clone evaluation scripts: + ```sh cd third_party git clone https://github.com/ElementAI/spider.git @@ -113,61 +127,76 @@ cd .. ``` ## Prepare data + Download [data](https://drive.google.com/file/d/19tsgBGAxpagULSl9r85IFKIZb4kyBGGu/view?usp=sharing) **(including Spider, Spider-DK, Spider-Syn, Spider-Realistic, Dr.Spider, and CSpider)** and [database](https://drive.google.com/file/d/1s4ItreFlTa8rUdzwVRmUR2Q9AHnxbNjo/view?usp=share_link) and then unzip them: + ```sh unzip data.zip unzip database.zip ``` + Notice: Dr.Spider has been preprocessed following the instructions on its Github page. ## Inference + All evaluation results can be easily reproduced through our released scripts and checkpionts. + ### Step1: Prepare Checkpoints + Because RESDSQL is a two-stage algorithm, therefore, you should first download cross-encoder checkpoints. Here are links: -| Cross-encoder Checkpoints | Google Drive | Baidu Netdisk | -|----------|-----------|--------------| -| text2natsql_schema_item_classifier| [Link](https://drive.google.com/file/d/1UWNj1ZADfKa1G5I4gBYCJeEQO6piMg4G/view?usp=share_link) | [Link](https://pan.baidu.com/s/15eGPMSx7K8oLV4hkjnCzaw) (pwd: 18w8) | -| text2sql_schema_item_classifier | [Link](https://drive.google.com/file/d/1zHAhECq1uGPR9Rt1EDsTai1LbRx0jYIo/view?usp=share_link) | [Link](https://pan.baidu.com/s/1trSi8OBOcPo5NkZb_o-T4g) (pwd: dr62) | -| xlm_roberta_text2natsql_schema_item_classifier (trained on CSpider) | - | [Link](https://pan.baidu.com/s/1oTkuqoU-RWr3QsNC9Y-RpA) (pwd: 3sdu) | + +| Cross-encoder Checkpoints | Google Drive | Baidu Netdisk | +| ------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | ---------------------------------------------------------------- | +| text2natsql_schema_item_classifier | [Link](https://drive.google.com/file/d/1UWNj1ZADfKa1G5I4gBYCJeEQO6piMg4G/view?usp=share_link) | [Link](https://pan.baidu.com/s/15eGPMSx7K8oLV4hkjnCzaw) (pwd: 18w8) | +| text2sql_schema_item_classifier | [Link](https://drive.google.com/file/d/1zHAhECq1uGPR9Rt1EDsTai1LbRx0jYIo/view?usp=share_link) | [Link](https://pan.baidu.com/s/1trSi8OBOcPo5NkZb_o-T4g) (pwd: dr62) | +| xlm_roberta_text2natsql_schema_item_classifier (trained on CSpider) | - | [Link](https://pan.baidu.com/s/1oTkuqoU-RWr3QsNC9Y-RpA) (pwd: 3sdu) | Then, you should download T5 (for Spider) or mT5 (for CSpider) checkpoints: -| T5/mT5 Checkpoints | Google Drive/OneDrive | Baidu Netdisk | -|-------|-------|-------| -| text2natsql-t5-3b | [OneDrive link](https://1drv.ms/u/s!Ak05bBUBFYiktcdziiE79xaeKtO6qg?e=e9424n) | [Link](https://pan.baidu.com/s/1ReHso0QgX5aT-hGySUnlrQ) (pwd: 4r98) | -| text2sql-t5-3b | [Google Drive link](https://drive.google.com/file/d/1M-zVeB6TKrvcIzaH8vHBIKeWqPn95i11/view?usp=sharing) | [Link](https://pan.baidu.com/s/1mZxakfes4wRSEwnRW43i5A) (pwd: sc62) | -| text2natsql-t5-large | [Google Drive link](https://drive.google.com/file/d/1ZwFsH24_qKC3xwYdedPi6T_8argguWHe/view?usp=sharing) | [Link](https://pan.baidu.com/s/18H8lgnv9gfXmUo_oO_CdOA) (pwd: 7iyq) | -| text2sql-t5-large | [Google Drive link](https://drive.google.com/file/d/1-xwtKwfJZSrmJrU-_Xdkx1kPuZao7r7e/view?usp=sharing) | [Link](https://pan.baidu.com/s/1Mwg0OZZ48APEq9jPvQQNtw) (pwd: q58k) | -| text2natsql-t5-base | [Google Drive link](https://drive.google.com/file/d/1QyfSfHHrxfIM5X9gKUYNr_0ZRVvb1suV/view?usp=share_link) | [Link](https://pan.baidu.com/s/1XegaZFvXuZ_jf3P-9YPQCQ) (pwd: pyxf) | -| text2sql-t5-base | [Google Drive link](https://drive.google.com/file/d/1lqZ81f_fSZtg6BRcRw1-Ol-RJCcKRsmH/view?usp=sharing) | [Link](https://pan.baidu.com/s/1-6H7zStq0WCJHTjDuVspoQ) (pwd: wuek) | -| text2natsql-mt5-xl-cspider (trained on CSpider) | - | [Link](https://pan.baidu.com/s/1tFkGOiw5ETB83-Ct3MuVXA) (pwd: y7ei) | -| text2natsql-mt5-large-cspider (trained on CSpider) | - | [Link](https://pan.baidu.com/s/1LUjL-2nwNfUJhzI3cm7aEQ) (pwd: ydqk) | -| text2natsql-mt5-base-cspider (trained on CSpider) | - | [Link](https://pan.baidu.com/s/1tbEUIBPUTA2Oz7K2lHT9oA) (pwd: d8b8) | - -The checkpoints should be placed in the `models` folder. + +| T5/mT5 Checkpoints | Google Drive/OneDrive | Baidu Netdisk | +| -------------------------------------------------- | ------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------- | +| text2natsql-t5-3b | [OneDrive link](https://1drv.ms/u/s!Ak05bBUBFYiktcdziiE79xaeKtO6qg?e=e9424n) | [Link](https://pan.baidu.com/s/1ReHso0QgX5aT-hGySUnlrQ) (pwd: 4r98) | +| text2sql-t5-3b | [Google Drive link](https://drive.google.com/file/d/1M-zVeB6TKrvcIzaH8vHBIKeWqPn95i11/view?usp=sharing) | [Link](https://pan.baidu.com/s/1mZxakfes4wRSEwnRW43i5A) (pwd: sc62) | +| text2natsql-t5-large | [Google Drive link](https://drive.google.com/file/d/1ZwFsH24_qKC3xwYdedPi6T_8argguWHe/view?usp=sharing) | [Link](https://pan.baidu.com/s/18H8lgnv9gfXmUo_oO_CdOA) (pwd: 7iyq) | +| text2sql-t5-large | [Google Drive link](https://drive.google.com/file/d/1-xwtKwfJZSrmJrU-_Xdkx1kPuZao7r7e/view?usp=sharing) | [Link](https://pan.baidu.com/s/1Mwg0OZZ48APEq9jPvQQNtw) (pwd: q58k) | +| text2natsql-t5-base | [Google Drive link](https://drive.google.com/file/d/1QyfSfHHrxfIM5X9gKUYNr_0ZRVvb1suV/view?usp=share_link) | [Link](https://pan.baidu.com/s/1XegaZFvXuZ_jf3P-9YPQCQ) (pwd: pyxf) | +| text2sql-t5-base | [Google Drive link](https://drive.google.com/file/d/1lqZ81f_fSZtg6BRcRw1-Ol-RJCcKRsmH/view?usp=sharing) | [Link](https://pan.baidu.com/s/1-6H7zStq0WCJHTjDuVspoQ) (pwd: wuek) | +| text2natsql-mt5-xl-cspider (trained on CSpider) | - | [Link](https://pan.baidu.com/s/1tFkGOiw5ETB83-Ct3MuVXA) (pwd: y7ei) | +| text2natsql-mt5-large-cspider (trained on CSpider) | - | [Link](https://pan.baidu.com/s/1LUjL-2nwNfUJhzI3cm7aEQ) (pwd: ydqk) | +| text2natsql-mt5-base-cspider (trained on CSpider) | - | [Link](https://pan.baidu.com/s/1tbEUIBPUTA2Oz7K2lHT9oA) (pwd: d8b8) | + +The checkpoints should be placed in the `models` folder. For CSpider, we only provide the NatSQL version because its performance is better than SQL in our pre-experiments. To support CSpider, we replace roberta-large with xlm-roberta-large in the first stage and replace t5 with mt5 in the second stage. ### Step2: Run Inference -The inference scripts are located in `scripts/inference`. + +The inference scripts are located in `scripts/inference`. Concretely, `infer_text2natsql.sh` is the inference script of RESDSQL-{Base, Large, 3B}+NatSQL, and `infer_text2sql.sh` is the inference script of RESDSQL-{Base, Large, 3B}. For example, you can run the inference of RESDSQL-3B+NatSQL on Spider's dev set via: + ```sh sh scripts/inference/infer_text2natsql.sh 3b spider ``` + The first argument (model scale) can be selected from `[base, large, 3b]` and the second argument (dataset name) can be selected from `[spider, spider-realistic, spider-syn, spider-dk, DB_schema_synonym, DB_schema_abbreviation, DB_DBcontent_equivalence, NLQ_keyword_synonym, NLQ_keyword_carrier, NLQ_column_synonym, NLQ_column_carrier, NLQ_column_attribute, NLQ_column_value, NLQ_value_synonym, NLQ_multitype, NLQ_others, SQL_comparison, SQL_sort_order, SQL_NonDB_number, SQL_DB_text, SQL_DB_number]`. The predicted SQL queries are recorded in `predictions/{dataset_name}/{model_name}/pred.sql`. **Inference on CSpider's Dev Set (New Feature)** We also provide inference scripts to run RESDSQL-{Base, Large, 3B}+NatSQL on CSpider's development set. Here is an example: + ```sh sh scripts/inference/infer_text2natsql_cspider.sh 3b ``` + The first argument (model scale) can be selected from `[base, large, 3b]`. ## Training on Spider + We provide scripts in `scripts/train/text2natsql` and `scripts/train/text2sql` to train RESDSQL on Spider's training set and evaluate on Spider's dev set. **RESDSQL-{Base, Large, 3B}+NatSQL** + ```sh # Step1: preprocess dataset sh scripts/train/text2natsql/preprocess.sh @@ -184,6 +213,7 @@ sh scripts/train/text2natsql/train_text2natsql_t5_base.sh ``` **RESDSQL-{Base, Large, 3B}** + ```sh # Step1: preprocess dataset sh scripts/train/text2sql/preprocess.sh @@ -202,6 +232,7 @@ sh scripts/train/text2sql/train_text2sql_t5_base.sh **During training, the cross-encoder (i.e., the first stage) always keeps the best checkpoint, but T5 (i.e., the second stage) keeps all the intermediate checkpoints, because different test sets may achieve the best Text-to-SQL performance on different checkpoints**. Therefore, given a test set, we need to evaluate all the intermediate checkpoints and compare their performance to find the best checkpoint. The evaluation results of checkpoints are saved in `eval_results`. Our paper also report the performence of RESDSQL-3B+NatSQL (the most powerful version of RESDSQL) on Spider-DK, Spider-Syn, and Spider-Realistic. To obtain results on these datasets, we provide evaluation scripts in `scripts/evaluate_robustness`. Here is an example for Spider-DK: + ```sh # Step1: preprocess Spider-DK sh scripts/evaluate_robustness/preprocess_spider_dk.sh @@ -210,9 +241,11 @@ sh scripts/evaluate_robustness/evaluate_on_spider_dk.sh ``` ## Training on CSpider + We additionally provide scripts in `scripts/train/cspider_text2natsql` and `scripts/train/cspider_text2sql` to train RESDSQL on CSpider's training set and evaluate on CSpider's dev set. **RESDSQL-{Base, Large, 3B}+NatSQL (CSpider version)** + ```sh # Step1: preprocess CSpider sh scripts/train/cspider_text2natsql/preprocess.sh @@ -231,6 +264,7 @@ sh scripts/train/cspider_text2natsql/train_text2natsql_mt5_base.sh In order to train the NatSQL version on CSpider, we manually aligned and modified annotations of NatSQL. The aligned files are also released, see `NatSQL/NatSQLv1_6/train_cspider-natsql.json` and `NatSQL/NatSQLv1_6/dev_cspider-natsql.json`. **RESDSQL-{Base, Large, 3B} (CSpider version)** + ```sh # Step1: preprocess CSpider sh scripts/train/cspider_text2sql/preprocess.sh @@ -246,6 +280,6 @@ sh scripts/train/cspider_text2sql/train_text2sql_mt5_large.sh sh scripts/train/cspider_text2sql/train_text2sql_mt5_base.sh ``` - ## Acknowledgements + We would thanks to Hongjin Su and Tao Yu for their help in evaluating our method on Spider's test set. We would also thanks to PICARD ([paper](https://arxiv.org/abs/2109.05093), [code](https://github.com/ServiceNow/picard)), NatSQL ([paper](https://arxiv.org/abs/2109.05153), [code](https://github.com/ygan/NatSQL)), Spider ([paper](https://arxiv.org/abs/1809.08887), [dataset](https://yale-lily.github.io/spider)), Spider-DK ([paper](https://arxiv.org/abs/2109.05157), [dataset](https://github.com/ygan/Spider-DK)), Spider-Syn ([paper](https://arxiv.org/abs/2106.01065), [dataset](https://github.com/ygan/Spider-Syn)), Spider-Realistic ([paper](https://arxiv.org/abs/2010.12773), [dataset](https://doi.org/10.5281/zenodo.5205322)), Dr.Spider ([paper](https://openreview.net/pdf?id=Wc5bmZZU9cy), [dataset](https://github.com/awslabs/diagnostic-robustness-text-to-sql)), and CSpider ([paper](https://arxiv.org/abs/1909.13293), [dataset](https://taolusi.github.io/CSpider-explorer/)) for their interesting work and open-sourced code and dataset. diff --git a/requirements.txt b/requirements.txt index a265223..1a4f7b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ nltk==3.7 numpy==1.22.3 rapidfuzz==2.0.11 scikit_learn==1.2.1 -spacy==2.2.3 +spacy==3.7.2 sql_metadata==2.6.0 sqlparse==0.4.2 tokenizers==0.11.6