-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathgenomic_clustering.py
More file actions
68 lines (47 loc) · 2.03 KB
/
genomic_clustering.py
File metadata and controls
68 lines (47 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import sys
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.clustering import KMeans
def encodeGenotype(genotype):
    """Encode one VCF GT field as a float allele dose.

    Counts the alleles that are not the reference allele '0', e.g.
    '0|0' -> 0.0, '0|1' -> 1.0, '1|1' -> 2.0.  Handles both phased ('|')
    and unphased ('/') genotype separators per the VCF specification;
    the original lambda split only on '|', so '1/1' was undercounted.
    """
    # Normalize the unphased separator so both notations are counted alike.
    alleles = genotype.replace('/', '|').split('|')
    return float(sum(allele != '0' for allele in alleles))
def encodeGenotypes(genotypes):
    """Encode a sequence of VCF GT strings as one dense Spark ML vector.

    Each genotype is mapped to its allele-dose float via encodeGenotype.
    The results are materialized into a list before building the vector:
    under Python 3, passing a lazy map iterator straight to Vectors.dense
    makes numpy produce a zero-dimensional object array and fail, which
    is what the original lambda did.
    """
    return Vectors.dense([encodeGenotype(g) for g in genotypes])
def splitGenotypeLine(line):
    """Split one VCF data line into its nine fixed columns followed by a
    single encoded-genotype vector covering every sample column."""
    fields = line.split()
    fixed, samples = fields[:9], fields[9:]
    return fixed + [encodeGenotypes(samples)]
def usage():
    """Print the command-line synopsis and terminate with exit status 1."""
    synopsis = ("genomic_clustering.py input-vcf-file output-csv-file "
                "[k-clusters = 10] [spark-pararellism=def]")
    print(synopsis)
    sys.exit(1)
def main(argv):
    """Run KMeans clustering over the samples of a VCF file with Spark.

    Command-line arguments (argv, without the program name):
      argv[0]  input VCF path
      argv[1]  output CSV path (one row per sample: sample id plus the
               per-cluster center coordinate for that sample)
      argv[2]  optional number of clusters (default 10)
      argv[3]  optional Spark read parallelism (default: Spark's own)

    Exits via usage() when fewer than two arguments are given.
    """
    if len(argv) < 2:
        usage()
    inputPath = argv[0]
    outputPath = argv[1]
    kClusters = int(argv[2]) if len(argv) > 2 else 10
    sparkPar = int(argv[3]) if len(argv) > 3 else None
    print("Running with input: %s, output: %s, kClusters: %s, sparkPar: %s" % (inputPath, outputPath, kClusters, sparkPar))
    spark = SparkSession.builder \
        .appName("Genomic Clustering") \
        .getOrCreate()
    try:
        sc = spark.sparkContext
        chr22RDD = sc.textFile(inputPath, sparkPar or sc.defaultMinPartitions)
        print("Loaded %s with %s partitions" % (inputPath, chr22RDD.getNumPartitions()))
        print(chr22RDD.take(10))
        # The '#CHROM' header line names the 9 fixed VCF columns followed
        # by one column per sample.
        header = chr22RDD.filter(lambda line: line.startswith("#CHROM")).map(lambda line: line.split()).first()
        print(header[0:10])
        columnNames = header[0:9]
        sampleNames = header[9:]
        print(sampleNames[0:10])
        # Data lines become 9 string columns plus one dense vector holding
        # the encoded genotypes of every sample.
        df = spark.createDataFrame(chr22RDD.filter(lambda line: not line.startswith("#")).map(splitGenotypeLine),
                                   schema=columnNames + ['encoded_genotypes'])
        df.cache()
        print("DF Count: %s" % df.count())
        df.printSchema()
        df.limit(10).show()
        kMeans = KMeans(featuresCol='encoded_genotypes', k=kClusters, initMode='random')
        kMeansModel = kMeans.fit(df)
        # Each cluster center has one coordinate per sample; transposing
        # yields one row per sample with kClusters columns.
        clusterCentersPD = pd.DataFrame.from_records(kMeansModel.clusterCenters()).T
        clusterCentersPD.insert(0, 'Sample ID', sampleNames)
        spark.createDataFrame(clusterCentersPD).coalesce(1) \
            .write.csv(path=outputPath, mode='overwrite', header=True)
    finally:
        # Always release the Spark session, even when the pipeline raises —
        # the original leaked the session on any failure.
        spark.stop()


if __name__ == "__main__":
    main(sys.argv[1:])