This repository contains the entire code for our TWC work "Iterative Algorithm Induced Deep-Unfolding Neural Networks: Precoding Design for Multiuser MIMO Systems", available at: https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9246287 and has been accepted for publication in TWC.
For any reproduce, further research or development, please kindly cite our TWC Journal paper:
Q. Hu, Y. Cai, Q. Shi, K. Xu, G. Yu, and Z. Ding, “Iterative algorithm induced deep-unfolding neural networks: Precoding design for multiuser MIMO systems,” IEEE Trans. Wireless Commun., to be published, DOI: 10.1109/TWC.2020.3033334.
The following versions have been tested: Python 3.6 + Tensorflow 1.12. But newer versions should also be fine.
There are three folders: "DeepUnfolding", "Blackbox CNN", and "WMMSE", where "DeepUnfolding" corresponds to the proposed deep-unfolding network in our paper, "Blackbox CNN" and "WMMSE" are benchmarks.
Run the main program "train_main.py".
train_main.py: Main program that implements the training and testing stages;
objective_func.py: The sum-rate (loss) function;
UW_gradient.py: The gradients of the variables (U and W) in the last layer of the deep-unfolding neural network;
UW_conj_gradient.py: The conjugate gradients of the variables (U and W) in the last layer of the deep-unfolding neural network;
generate_channel.py: Generate the complex Gaussian channels, which could be modified to other channel models;
test_model.py: Import the trained model and test its performance.
Firstly, we run "generate_data.m" in the folder "GenerateData" to generate the training dataset, which consists of the inputs in the file "Input_H.csv" and the labels in the file "Output_UW.csv". The two files are generated in the folder "mat" and should be copied into the folder "DataSet".
Run "Blackbox CNN.py", which generates four ".csv" files in the folder "DataSet". Finally, we run the file "test_predict.m" in the folder "Test" to see the sum-rate performance. Note that the file path in "test_predict.m" should be modified correspondingly.
Run the main program "WMMSE.py", which implements the iterative WMMSE algorithm.