RNN-Pretraining-Java

The goal of this project is to pre-train a vanilla RNN to learn the general syntax of Java code, without targeting any specific downstream task. The training data is a constructed masked language modeling (MLM) dataset built from a collection of Java methods. Each method was tokenized with javalang to split the code into individual tokens, and 15% of the tokens were randomly selected for masking. For each masked token, we generated one dataset row containing the original input, the prepared input (with the masked token removed), and the output (the token that was masked). For simplicity's sake, not every possible masked instance was included for each method.
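As a rough sketch of how such dataset rows can be built, the snippet below masks 15% of a pre-tokenized method and emits one row per masked position. The function name, mask placeholder, and row layout are illustrative assumptions, not the project's actual code (the real pipeline tokenizes with javalang; here the input is already a token list, and the masked position is replaced with a placeholder rather than dropped):

```python
import random

MASK = "<MASK>"  # illustrative placeholder token

def make_mlm_rows(tokens, mask_frac=0.15, seed=0):
    """For each randomly selected position, emit one dataset row:
    (original tokens, prepared input with the token masked, target token)."""
    rng = random.Random(seed)
    n = max(1, int(len(tokens) * mask_frac))
    positions = rng.sample(range(len(tokens)), n)
    rows = []
    for pos in sorted(positions):
        masked = list(tokens)
        masked[pos] = MASK
        rows.append((tokens, masked, tokens[pos]))
    return rows

# Tokens roughly as javalang would split a small Java method.
tokens = ["public", "int", "add", "(", "int", "a", ",", "int", "b", ")",
          "{", "return", "a", "+", "b", ";", "}"]
rows = make_mlm_rows(tokens)  # 17 tokens -> 2 masked rows at 15%
```

Each row keeps the original sequence alongside the masked one, so the target token is always recoverable from the masked position.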

For evaluation, loss and accuracy were calculated. Loss uses PyTorch's built-in cross-entropy loss. Accuracy was computed manually by comparing the predicted tokens (taken with argmax) against the target tokens, i.e. the percentage of masked tokens the model predicted correctly. The best model results are at epoch 2, where the loss is minimal. The model does not train very well or reach strong performance because vanilla RNNs struggle with long-term dependencies, a problem largely addressed by LSTMs, a gated variant of the RNN.
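The accuracy calculation can be sketched without PyTorch: take the argmax over the vocabulary dimension for each prediction and count matches against the targets. This is a minimal stand-in for the project's actual implementation, which operates on PyTorch tensors:

```python
def accuracy(logits, targets):
    """Fraction of positions where the argmax over the vocabulary
    matches the target token index."""
    correct = 0
    for row, target in zip(logits, targets):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax
        correct += int(pred == target)
    return correct / len(targets)

logits = [[0.1, 2.0, 0.3],   # argmax -> 1
          [1.5, 0.2, 0.1],   # argmax -> 0
          [0.0, 0.1, 0.9]]   # argmax -> 2
targets = [1, 0, 1]
print(accuracy(logits, targets))  # -> 0.666... (2 of 3 correct)
```

In PyTorch the same idea is typically one line, e.g. `(outputs.argmax(dim=-1) == targets).float().mean()`.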

Installation:

  1. Install Python 3.9+ locally.
  2. Clone the repository to your workspace:
~ $ git clone https://github.com/WhittJS/RNN-Pretraining-Java.git
  3. Navigate into the repository:
~ $ cd RNN-Pretraining-Java
~/RNN-Pretraining-Java $
  4. Set up a virtual environment and activate it:
~/RNN-Pretraining-Java $ python -m venv ./venv/

For macOS/Linux:

~/RNN-Pretraining-Java $ source venv/bin/activate
(venv) ~/RNN-Pretraining-Java $ 

For Windows:

~\RNN-Pretraining-Java $ .\venv\Scripts\activate.bat
(venv) ~\RNN-Pretraining-Java $ 
  5. Install the required packages:
(venv) ~/RNN-Pretraining-Java $ pip install -r requirements.txt

Running the Program

  1. To train the model, run:
python rnn-pretraining.py
  • The trained model is saved to models/dl_code_completion.pth
  • The training results are saved to training_logs.csv
  2. The hyperparameters can be adjusted on lines 198-204 of rnn-pretraining.py; changing them will change the results.
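The block on those lines is a set of plain Python assignments; a hypothetical example of the kind of values involved is shown below. The names and numbers here are assumptions for illustration only, not the script's actual settings:

```python
# Hypothetical hyperparameter block; the real names and values live
# on lines 198-204 of rnn-pretraining.py.
EMBED_DIM = 128       # token embedding size
HIDDEN_DIM = 256      # RNN hidden state size
NUM_EPOCHS = 10       # training epochs
BATCH_SIZE = 64       # examples per batch
LEARNING_RATE = 1e-3  # optimizer step size
```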

About

Train an RNN to learn the syntax of Java code.
