README.md (60 changes: 52 additions & 8 deletions)

**This project is a prototype for experimental purposes only; production-grade code is not released here.**

# Deep LSTM siamese network for text similarity

It is a TensorFlow-based implementation of a deep Siamese LSTM network to capture phrase/sentence similarity using character (or word) embeddings.

This code provides an architecture for learning two kinds of tasks:
- Phrase similarity using char level embeddings [1]
![siamese lstm phrase similarity](https://cloud.githubusercontent.com/assets/9861437/20479454/405a1aea-b004-11e6-8a27-7bb05cf0a002.png)

- Sentence similarity using word level embeddings [2]
![siamese lstm sentence similarity](https://cloud.githubusercontent.com/assets/9861437/20479493/6ea8ad12-b004-11e6-89e4-53d4d354d32e.png)

For both tasks, it uses a multilayer Siamese LSTM network and a Euclidean-distance-based contrastive loss to learn input-pair similarity.
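
As a concrete illustration, here is a minimal NumPy sketch of a Euclidean-distance contrastive loss of the kind described above; the margin value and the averaging are assumptions for illustration, not taken from this repository's implementation:

```python
import numpy as np

def contrastive_loss(d, y, margin=1.0):
    # d: Euclidean distances between the two LSTM branch outputs.
    # y: labels, 1 for a similar pair, 0 for a dissimilar pair.
    # Similar pairs are pulled together (the d^2 term); dissimilar
    # pairs are pushed until they are at least `margin` apart.
    similar_term = y * np.square(d)
    dissimilar_term = (1.0 - y) * np.square(np.maximum(margin - d, 0.0))
    return np.mean(similar_term + dissimilar_term) / 2.0

# A similar pair at distance 0.1 and a dissimilar pair at distance 0.9
# both incur only a small penalty:
print(contrastive_loss(np.array([0.1, 0.9]), np.array([1.0, 0.0])))  # 0.005
```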

# Capabilities
Given adequate training pairs, this model can learn semantic as well as structural similarity. For example:

**Phrases:**
- International Business Machines = I.B.M
- Synergy Telecom = SynTel
- Beam inc = Beam Incorporate
- ...
- James B. D. Joshi = James Joshi
- James Beaty, Jr. = Beaty

For phrases, the model learns **character-based embeddings** to identify structural/syntactic similarities.

**Sentences:**
- He is smart = He is a wise man.
- Someone is travelling countryside = He is travelling to a village.
- She is cooking a dessert = Pudding is being cooked.
- Microsoft to acquire Linkedin ≠ Linkedin to acquire microsoft

(For more examples, see the SemEval dataset.)

For sentences, the model uses **pre-trained word embeddings** to identify semantic similarities.
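
To make the two input representations concrete, here is a minimal, hypothetical sketch; the helper functions and the tiny alphabet/vocabulary are illustrative only and are not taken from this repository:

```python
# Character-level ids for phrases (structural similarity) versus
# word-level ids for sentences (semantic similarity).

def char_ids(phrase, alphabet):
    # Hypothetical helper: one id per character.
    return [alphabet.index(c) for c in phrase.lower() if c in alphabet]

def word_ids(sentence, vocab):
    # Hypothetical helper: one id per whitespace-separated token.
    return [vocab.get(tok, vocab["<unk>"]) for tok in sentence.lower().split()]

alphabet = "abcdefghijklmnopqrstuvwxyz. "
vocab = {"<unk>": 0, "he": 1, "is": 2, "smart": 3}
print(char_ids("I.B.M", alphabet))    # [8, 26, 1, 26, 12]
print(word_ids("He is smart", vocab)) # [1, 2, 3]
```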

Categories of pairs it can learn as similar:
- Annotations
- Abbreviations
- ...
- Summaries

# Training Data
- **Phrases:**
  - A sample set of person name paraphrases for learning has been attached to this repository. To generate the full person name disambiguation data, follow the steps described at:

> https://github.com/dhwajraj/dataset-person-name-disambiguation

"person_match.train" : https://drive.google.com/open?id=1HnMv7ulfh8yuq9yIrt_IComGEpDrNyo-
- **Sentences:**
  - A sample set for learning sentence semantic similarity can be downloaded from:

"train_snli.txt" : https://drive.google.com/open?id=1itu7IreU_SyUSdmTWydniGxW-JEGTjrv

  - This data is generated using the SNLI project:
> https://nlp.stanford.edu/projects/snli/

- **Word embeddings:** any set of pre-trained word embeddings can be used in this project. For our testing we used the fastText simple English embeddings from https://github.com/facebookresearch/fastText/blob/master/pretrained-vectors.md (see the loading sketch below).

  - An alternate download location for "wiki.simple.vec" is: https://drive.google.com/open?id=1u79f3d2PkmePzyKgubkbxOjeaZCJgCrt
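
For reference, a minimal sketch of loading such embeddings with gensim (listed under Environment below); this assumes "wiki.simple.vec" is in plain-text word2vec format, matching the --word2vec_format default of "text":

```python
from gensim.models import KeyedVectors

# binary=False because "wiki.simple.vec" is a plain-text word2vec file.
vectors = KeyedVectors.load_word2vec_format("wiki.simple.vec", binary=False)

print(vectors["king"].shape)               # embedding dimensionality, e.g. (300,)
print(vectors.similarity("king", "queen")) # cosine similarity between two words
```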

# Environment
- numpy 1.11.0
- tensorflow 1.2.1
- gensim 1.0.1
- nltk 3.2.2

# How to run
### Training
```
$ python train.py [options/defaults]

options:
  -h, --help            show this help message and exit
  --is_char_based IS_CHAR_BASED
                        whether character-based syntactic similarity is to be
                        used for phrases; if False, word-embedding-based
                        semantic similarity is used (default: True)
  --word2vec_model WORD2VEC_MODEL
                        used only if IS_CHAR_BASED is False; word2vec
                        pre-trained embeddings file (default: wiki.simple.vec)
  --word2vec_format WORD2VEC_FORMAT
                        used only if IS_CHAR_BASED is False; word2vec
                        pre-trained embeddings file format (bin/text/textgz)
                        (default: text)
  --embedding_dim EMBEDDING_DIM
                        Dimensionality of character embedding (default: 100)
  --dropout_keep_prob DROPOUT_KEEP_PROB
                        ...
```

### Evaluation
```
$ python eval.py --model graph#.pb
```
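
At evaluation time, each pair is scored by the distance between its two encoded vectors. Below is a minimal NumPy sketch of turning two branch outputs into a similarity decision; the normalization and the 0.5 threshold are assumptions for illustration, not constants taken from eval.py:

```python
import numpy as np

def pair_distance(v1, v2):
    # Euclidean distance between the two Siamese branch outputs,
    # scaled into [0, 1] (an assumed normalization for illustration).
    return np.linalg.norm(v1 - v2) / (np.linalg.norm(v1) + np.linalg.norm(v2))

v1 = np.array([0.20, 0.90, 0.10])
v2 = np.array([0.25, 0.85, 0.05])
d = pair_distance(v1, v2)
print("similar" if d < 0.5 else "dissimilar", round(d, 3))  # similar 0.048
```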

# Performance
**Phrases:**
- Training time (8-core CPU): 1 complete epoch in 6 min 48 s (training requires at least 30 epochs)
- Contrastive loss: 0.0248
- Evaluation performance: similarity scores for 100,000 pairs (8-core CPU) in 1 min 40 s
- Accuracy: 91%

**Sentences:**
- Training time (8-core CPU): 1 complete epoch in 8 min 10 s (training requires at least 50 epochs)
- Contrastive loss: 0.0477
- Evaluation performance: similarity scores for 100,000 pairs (8-core CPU) in 2 min 10 s
- Accuracy: 81%

# References
1. [Learning Text Similarity with Siamese Recurrent Networks](http://www.aclweb.org/anthology/W16-16#page=162)
2. [Siamese Recurrent Architectures for Learning Sentence Similarity](http://www.mit.edu/~jonasm/info/MuellerThyagarajan_AAAI16.pdf)

eval.py (10 changes: 5 additions & 5 deletions)

```python
# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
tf.flags.DEFINE_string("eval_filepath", "match_valid.tsv", "Evaluate on this data (Default: None)")
tf.flags.DEFINE_string("vocab_filepath", "runs/1479874609/checkpoints/vocab", "Load training time vocabulary (Default: None)")
tf.flags.DEFINE_string("model", "runs/1479874609/checkpoints/model-32000", "Load trained model checkpoint (Default: None)")
tf.flags.DEFINE_string("eval_filepath", "validation.txt0", "Evaluate on this data (Default: None)")
tf.flags.DEFINE_string("vocab_filepath", "runs/1512222837/checkpoints/vocab", "Load training time vocabulary (Default: None)")
tf.flags.DEFINE_string("model", "runs/1512222837/checkpoints/model-5000", "Load trained model checkpoint (Default: None)")

# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow soft device placement")
# ...
all_d = []
for db in batches:
    x1_dev_b, x2_dev_b, y_dev_b = zip(*db)
    # Fetch the similarity tensor into batch_sim (not sim) so the
    # graph tensor is not overwritten on the first iteration.
    batch_predictions, batch_acc, batch_sim = sess.run(
        [predictions, accuracy, sim],
        {input_x1: x1_dev_b, input_x2: x2_dev_b, input_y: y_dev_b, dropout_keep_prob: 1.0})
    # Accumulate per-batch predictions and similarity scores.
    all_predictions = np.concatenate([all_predictions, batch_predictions])
    print(batch_predictions)
    all_d = np.concatenate([all_d, batch_sim])
    print("DEV acc {}".format(batch_acc))
for ex in all_predictions:
    print(ex)
```