Skip to content

Commit 96c2bb6

Browse files
authored
feat: read spark.master and set correct sql file path (#1996)
* Read spark.master and set the correct SQL file path
* Reuse util methods for all batch jobs
1 parent 79143ed commit 96c2bb6

File tree

6 files changed

+80
-51
lines changed

6 files changed

+80
-51
lines changed

java/openmldb-batchjob/src/main/scala/com/_4paradigm/openmldb/batchjob/ExportOfflineData.scala

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,19 @@
1616

1717
package com._4paradigm.openmldb.batchjob

import com._4paradigm.openmldb.batchjob.util.OpenmldbJobUtil
import org.apache.spark.sql.SparkSession

/** Batch job entry point that exports offline data by running a user-supplied SQL script. */
object ExportOfflineData {

  /** Expects exactly one argument: the path of the SQL file to execute. */
  def main(args: Array[String]): Unit = {
    OpenmldbJobUtil.checkOneSqlArgument(args)
    exportOfflineData(args(0))
  }

  /** Executes the SQL read from `sqlFilePath` inside an OpenMLDB session. */
  def exportOfflineData(sqlFilePath: String): Unit = {
    OpenmldbJobUtil.runOpenmldbSql(SparkSession.builder().getOrCreate(), sqlFilePath)
  }
}

java/openmldb-batchjob/src/main/scala/com/_4paradigm/openmldb/batchjob/ImportOfflineData.scala

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,19 @@
1616

1717
package com._4paradigm.openmldb.batchjob

import com._4paradigm.openmldb.batchjob.util.OpenmldbJobUtil
import org.apache.spark.sql.SparkSession

/** Batch job entry point that imports offline data by running a user-supplied SQL script. */
object ImportOfflineData {

  /** Expects exactly one argument: the path of the SQL file to execute. */
  def main(args: Array[String]): Unit = {
    OpenmldbJobUtil.checkOneSqlArgument(args)
    importOfflineData(args(0))
  }

  /** Executes the SQL read from `sqlFilePath` inside an OpenMLDB session. */
  def importOfflineData(sqlFilePath: String): Unit = {
    OpenmldbJobUtil.runOpenmldbSql(SparkSession.builder().getOrCreate(), sqlFilePath)
  }
}

java/openmldb-batchjob/src/main/scala/com/_4paradigm/openmldb/batchjob/ImportOnlineData.scala

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,19 @@
1616

1717
package com._4paradigm.openmldb.batchjob

import com._4paradigm.openmldb.batchjob.util.OpenmldbJobUtil
import org.apache.spark.sql.SparkSession

/** Batch job entry point that imports data into online storage by running a SQL script. */
object ImportOnlineData {

  /** Expects exactly one argument: the path of the SQL file to execute. */
  def main(args: Array[String]): Unit = {
    OpenmldbJobUtil.checkOneSqlArgument(args)
    importOnlineData(args(0))
  }

  /** Executes the SQL read from `sqlFilePath` in a session configured for online load mode. */
  def importOnlineData(sqlFilePath: String): Unit = {
    // The online load-data mode must be set before the session is created.
    val sessionBuilder = SparkSession.builder().config("openmldb.loaddata.mode", "online")
    OpenmldbJobUtil.runOpenmldbSql(sessionBuilder.getOrCreate(), sqlFilePath)
  }
}

java/openmldb-batchjob/src/main/scala/com/_4paradigm/openmldb/batchjob/RunBatchAndShow.scala

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,22 @@
1717
package com._4paradigm.openmldb.batchjob

import com._4paradigm.openmldb.batch.api.OpenmldbSession
import com._4paradigm.openmldb.batchjob.util.OpenmldbJobUtil
import org.apache.spark.sql.SparkSession

/** Batch job entry point that runs a SQL script and prints the resulting rows to stdout. */
object RunBatchAndShow {

  /** Expects exactly one argument: the path of the SQL file to execute. */
  def main(args: Array[String]): Unit = {
    OpenmldbJobUtil.checkOneSqlArgument(args)
    runBatchSql(args(0))
  }

  /**
   * Reads the SQL text from `sqlFilePath`, executes it and shows the result.
   *
   * The session is closed in a finally block so it is released even when the
   * SQL execution throws (the previous code leaked the session on failure).
   */
  def runBatchSql(sqlFilePath: String): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    val sqlText = OpenmldbJobUtil.getSqlFromFile(spark, sqlFilePath)

    val sess = new OpenmldbSession(spark)
    try {
      sess.sql(sqlText).show()
    } finally {
      sess.close()
    }
  }
}

java/openmldb-batchjob/src/main/scala/com/_4paradigm/openmldb/batchjob/RunBatchSql.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,23 @@
1717
package com._4paradigm.openmldb.batchjob

import com._4paradigm.openmldb.batch.api.OpenmldbSession
import com._4paradigm.openmldb.batchjob.util.OpenmldbJobUtil
// NOTE(review): SparkFiles is no longer referenced here after delegating file
// resolution to OpenmldbJobUtil; kept for now, candidate for removal.
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession

/** Batch job entry point that runs a SQL script and prints the resulting rows to stdout. */
object RunBatchSql {

  /** Expects exactly one argument: the path of the SQL file to execute. */
  def main(args: Array[String]): Unit = {
    OpenmldbJobUtil.checkOneSqlArgument(args)
    runBatchSql(args(0))
  }

  /**
   * Reads the SQL text from `sqlFilePath`, executes it and shows the result.
   *
   * The session is closed in a finally block so it is released even when the
   * SQL execution throws (the previous code leaked the session on failure).
   */
  def runBatchSql(sqlFilePath: String): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    val sqlText = OpenmldbJobUtil.getSqlFromFile(spark, sqlFilePath)

    val sess = new OpenmldbSession(spark)
    try {
      sess.sql(sqlText).show()
    } finally {
      sess.close()
    }
  }
}
4139

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
 * Copyright 2021 4Paradigm
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com._4paradigm.openmldb.batchjob.util

import com._4paradigm.openmldb.batch.api.OpenmldbSession
import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession

import java.nio.file.{Files, Paths}

/** Shared helpers for the OpenMLDB batch job entry points. */
object OpenmldbJobUtil {

  /**
   * Validates that `args` holds exactly one element (the SQL file path).
   *
   * @throws Exception when the argument count is not exactly one.
   */
  def checkOneSqlArgument(args: Array[String]): Unit = {
    if (args.length != 1) {
      throw new Exception(s"Require args of sql but get args: ${args.mkString(",")}")
    }
  }

  /**
   * Resolves `sqlFilePath` depending on the Spark master and returns the file content.
   *
   * In local mode the file was distributed via `--files`, so it must be resolved
   * through [[SparkFiles.get]]; on a cluster the path is used as-is.
   */
  def getSqlFromFile(spark: SparkSession, sqlFilePath: String): String = {
    // Default to "" (non-local path handling) if spark.master is somehow unset,
    // instead of throwing NoSuchElementException.
    val sparkMaster = spark.conf.get("spark.master", "")

    // Local masters can be "local", "local[4]" or "local[*]"; a plain
    // equals("local") check would miss the bracketed forms.
    val actualSqlFilePath = if (sparkMaster.startsWith("local")) {
      SparkFiles.get(sqlFilePath)
    } else {
      sqlFilePath
    }

    if (!Files.exists(Paths.get(actualSqlFilePath))) {
      throw new Exception("SQL file does not exist in " + actualSqlFilePath)
    }

    // Close the Source explicitly so the file handle is not leaked.
    val source = scala.io.Source.fromFile(actualSqlFilePath)
    try {
      source.mkString
    } finally {
      source.close()
    }
  }

  /**
   * Runs the SQL read from `sqlFilePath` in a fresh OpenMLDB session.
   *
   * The session is closed in a finally block so it is released even when
   * the SQL execution throws.
   */
  def runOpenmldbSql(spark: SparkSession, sqlFilePath: String): Unit = {
    val sqlText = getSqlFromFile(spark, sqlFilePath)

    val sess = new OpenmldbSession(spark)
    try {
      sess.sql(sqlText)
    } finally {
      sess.close()
    }
  }
}

0 commit comments

Comments
 (0)