Skip to content
This repository was archived by the owner on Feb 12, 2022. It is now read-only.

Commit 001f9c5

Browse files
committed
Add JSON API.
Also optimise Dockerfile for quick rebuilds. Also add multiple input files support to `predict.py`.
1 parent ae15e1a commit 001f9c5

File tree

6 files changed

+130
-104
lines changed

6 files changed

+130
-104
lines changed

Dockerfile

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
FROM conda/miniconda3
22

3-
RUN apt-get update && \
4-
apt-get install -y libsndfile1
3+
RUN apt update && apt install -y g++
4+
5+
# Copy requirements.txt and run pip first so that changes to the application
6+
# code do not require a rebuild of the entire image
7+
COPY requirements.txt /app/
8+
RUN conda update conda && \
9+
conda install "keras<2.4" "numpy<2" "scikit-learn<0.23" && \
10+
conda install -c conda-forge librosa theano
511

612
ADD . /app
713
WORKDIR /app
814

915
VOLUME /data
1016

11-
RUN pip install --upgrade pip && \
12-
pip install -r requirements.txt
17+
ENV KERAS_BACKEND=theano

README.md

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ This repository provides the [keras](https://keras.io/) model to be used from Py
3939
[Docker](https://www.docker.com/) makes it easy to reproduce the results and install all requirements. If you have docker installed, run the following steps to predict a count from the provided test sample.
4040

4141
* Build the docker image: `docker build -t countnet .`
42-
* Predict from example: `docker run -i countnet python predict.py --model CRNN examples/5_speakers.wav`
42+
* Run like this: `docker run -it countnet python predict.py ...` (see usage details below)
43+
* Mount your data into the container: `docker run -v /path/to/your/data:/data -it countnet python predict.py ... /data/your_audio.wav`
4344

4445
### Manual Installation
4546

@@ -49,7 +50,46 @@ To install the requirements using Anaconda Python, run
4950

5051
You can now run the command line script and process wav files using the pre-trained model `CRNN` (best peformance).
5152

52-
`python predict.py examples/5_speakers.wav --model CRNN`
53+
```
54+
python predict.py --model CRNN examples/5_speakers.wav
55+
# => Speaker Count Estimate: examples/5_speakers.wav 5
56+
```
57+
58+
You can also pass multiple files at once.
59+
60+
```
61+
python predict.py --model CRNN examples/5_speakers.wav examples/5_speakers.wav
62+
# => Speaker Count Estimate: examples/5_speakers.wav 5
63+
# => Speaker Count Estimate: examples/5_speakers.wav 5
64+
```
65+
66+
There is also a simple JSON API to send audio data to (not production ready; only for development!). To run the server:
67+
68+
```
69+
python predict_api.py --model CRNN
70+
71+
# With Docker:
72+
docker run -p5000:5000 -it countnet python predict_api.py --model CRNN
73+
```
74+
75+
The server expects a JSON list of base64 encoded arrays of 16 kHz, float32 audio arrays. It returns a JSON list of integers. If estimation failed for any of the arrays, its result is set to `null` instead.
76+
77+
```py
78+
import base64
79+
import requests
80+
import librosa
81+
82+
audio_data1 = librosa.core.load("/path/to/5_speakers.wav", sr=16000, dtype="float32")[0]
83+
response = requests.post(
84+
"http://localhost:5000",
85+
json=[
86+
base64.b64encode(audio_data1.tobytes())
87+
]
88+
)
89+
print(response.json())
90+
# => [5]
91+
```
92+
5393

5494
## Reproduce Paper Results using the LibriCount Dataset
5595
[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.1216072.svg)](https://doi.org/10.5281/zenodo.1216072)

predict.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import soundfile as sf
32
import argparse
43
import os
54
import keras
@@ -20,6 +19,25 @@ def class_mae(y_true, y_pred):
2019
)
2120

2221

22+
def load_scaler():
23+
scaler = sklearn.preprocessing.StandardScaler()
24+
with np.load(os.path.join("models", 'scaler.npz')) as data:
25+
scaler.mean_ = data['arr_0']
26+
scaler.scale_ = data['arr_1']
27+
return scaler
28+
29+
30+
def load_model(model_name):
31+
path = os.path.join('models', model_name + '.h5')
32+
return keras.models.load_model(
33+
path,
34+
custom_objects={
35+
'class_mae': class_mae,
36+
'exp': K.exp
37+
}
38+
)
39+
40+
2341
def count(audio, model, scaler):
2442
# compute STFT
2543
X = np.abs(librosa.stft(audio, n_fft=400, hop_length=160)).T
@@ -51,38 +69,31 @@ def count(audio, model, scaler):
5169

5270
parser.add_argument(
5371
'audio',
54-
help='audio file (samplerate 16 kHz) of 5 seconds duration'
72+
help='audio file (samplerate 16 kHz) of 5 seconds duration',
73+
nargs='+',
5574
)
5675

5776
parser.add_argument(
5877
'--model', default='CRNN',
5978
help='model name'
6079
)
6180

81+
parser.add_argument('--print-summary', action='store_true')
82+
6283
args = parser.parse_args()
6384

6485
# load model
65-
model = keras.models.load_model(
66-
os.path.join('models', args.model + '.h5'),
67-
custom_objects={
68-
'class_mae': class_mae,
69-
'exp': K.exp
70-
}
71-
)
86+
model = load_model(args.model)
7287

73-
# print model configuration
74-
model.summary()
75-
# save as svg file
76-
# load standardisation parameters
77-
scaler = sklearn.preprocessing.StandardScaler()
78-
with np.load(os.path.join("models", 'scaler.npz')) as data:
79-
scaler.mean_ = data['arr_0']
80-
scaler.scale_ = data['arr_1']
88+
if args.print_summary:
89+
# print model configuration
90+
model.summary()
8191

82-
# compute audio
83-
audio, rate = sf.read(args.audio, always_2d=True)
92+
# load standardisation parameters
93+
scaler = load_scaler()
8494

85-
# downmix to mono
86-
audio = np.mean(audio, axis=1)
87-
estimate = count(audio, model, scaler)
88-
print("Speaker Count Estimate: ", estimate)
95+
for f in args.audio:
96+
# compute audio
97+
audio = librosa.load(f, sr=16000)[0]
98+
estimate = count(audio, model, scaler)
99+
print("Speaker Count Estimate:", f, estimate)

predict_api.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import base64
2+
import json
3+
import numpy as np
4+
from werkzeug.wrappers import Request, Response
5+
import predict
6+
7+
8+
def decode_audio(audio_bytes):
9+
return np.frombuffer(base64.b64decode(audio_bytes), dtype="float32")
10+
11+
12+
def make_app(estimate_func):
13+
def app(environ, start_response):
14+
inputs = json.loads(Request(environ).get_data())
15+
16+
outputs = []
17+
for inp in inputs:
18+
try:
19+
est = int(estimate_func(decode_audio(inp)))
20+
except Exception as e:
21+
print(f"Error estimating speaker count for input {len(outputs)}: {e}")
22+
est = None
23+
outputs.append(est)
24+
25+
return Response(json.dumps(outputs))(environ, start_response)
26+
27+
return app
28+
29+
30+
if __name__ == "__main__":
31+
import argparse
32+
import functools
33+
from werkzeug.serving import run_simple
34+
35+
parser = argparse.ArgumentParser(
36+
description="Run simple JSON api server to predict speaker count"
37+
)
38+
parser.add_argument("--model", default="CRNN", help="model name")
39+
args = parser.parse_args()
40+
41+
model = predict.load_model(args.model)
42+
scaler = predict.load_scaler()
43+
44+
app = make_app(functools.partial(predict.count, model=model, scaler=scaler))
45+
run_simple("0.0.0.0", 5000, app, use_debugger=True)

predict_audio.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

requirements.txt

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)