11package main
22
33import (
4+ "fmt"
45 "net/url"
56 "strings"
7+ "sync"
68
79 "github.com/gofiber/fiber/v2/utils"
810
@@ -11,6 +13,9 @@ import (
1113 "github.com/gofiber/fiber/v2"
1214)
1315
16+ // Maximum number of batch pushes allowed, -1 means no limit
17+ var maxBatchPushCount = - 1
18+
1419func init () {
1520 // V2 API
1621 registerRoute ("push" , func (router fiber.Router ) {
@@ -33,6 +38,11 @@ func init() {
3338 })
3439}
3540
41+ // Set the maximum number of batch pushes allowed
42+ func SetMaxBatchPushCount (count int ) {
43+ maxBatchPushCount = count
44+ }
45+
3646func routeDoPush (c * fiber.Ctx ) error {
3747 // Get content-type
3848 contentType := utils .ToLower (utils .UnsafeString (c .Request ().Header .ContentType ()))
@@ -61,7 +71,12 @@ func routeDoPush(c *fiber.Ctx) error {
6171 }
6272 }
6373
64- return push (c , params )
74+ code , err := push (c , params )
75+ if err != nil {
76+ return c .Status (code ).JSON (failed (code , err .Error ()))
77+ } else {
78+ return c .JSON (success ())
79+ }
6580}
6681
6782func routeDoPushV2 (c * fiber.Ctx ) error {
@@ -74,10 +89,75 @@ func routeDoPushV2(c *fiber.Ctx) error {
7489 c .Request ().URI ().QueryArgs ().VisitAll (func (key , value []byte ) {
7590 params [strings .ToLower (string (key ))] = string (value )
7691 })
77- return push (c , params )
92+
93+ var deviceKeys []string
94+ // Get the device_keys array from params
95+ if keys , ok := params ["device_keys" ]; ok {
96+ switch keys := keys .(type ) {
97+ case string :
98+ deviceKeys = strings .Split (keys , "," )
99+ case []interface {}:
100+ for _ , key := range keys {
101+ deviceKeys = append (deviceKeys , fmt .Sprint (key ))
102+ }
103+ default :
104+ return c .Status (400 ).JSON (failed (400 , "invalid type for device_keys" ))
105+ }
106+ delete (params , "device_keys" )
107+ }
108+
109+ count := len (deviceKeys )
110+
111+ if count == 0 {
112+ // Single push
113+ code , err := push (c , params )
114+ if err != nil {
115+ return c .Status (code ).JSON (failed (code , err .Error ()))
116+ } else {
117+ return c .JSON (success ())
118+ }
119+ } else {
120+ // Batch push
121+ if count > maxBatchPushCount && maxBatchPushCount != - 1 {
122+ return c .Status (400 ).JSON (failed (400 , "batch push count exceeds the maximum limit: %d" , maxBatchPushCount ))
123+ }
124+
125+ var wg sync.WaitGroup
126+ result := make ([]map [string ]interface {}, count )
127+ var mu sync.Mutex
128+
129+ for i := 0 ; i < count ; i ++ {
130+ // Copy params
131+ newParams := make (map [string ]interface {})
132+ for k , v := range params {
133+ newParams [k ] = v
134+ }
135+ newParams ["device_key" ] = deviceKeys [i ]
136+
137+ wg .Add (1 )
138+ go func (i int , newParams map [string ]interface {}) {
139+ defer wg .Done ()
140+
141+ // Push
142+ code , err := push (c , newParams )
143+
144+ // Save result
145+ mu .Lock ()
146+ result [i ] = make (map [string ]interface {})
147+ if err != nil {
148+ result [i ]["message" ] = err .Error ()
149+ }
150+ result [i ]["code" ] = code
151+ result [i ]["device_key" ] = deviceKeys [i ]
152+ mu .Unlock ()
153+ }(i , newParams )
154+ }
155+ wg .Wait ()
156+ return c .JSON (data (result ))
157+ }
78158}
79159
80- func push (c * fiber.Ctx , params map [string ]interface {}) error {
160+ func push (c * fiber.Ctx , params map [string ]interface {}) ( int , error ) {
81161 // default value
82162 msg := apns.PushMessage {
83163 Body : "" ,
@@ -123,27 +203,27 @@ func push(c *fiber.Ctx, params map[string]interface{}) error {
123203 if subtitle := c .Params ("subtitle" ); subtitle != "" {
124204 str , err := url .QueryUnescape (subtitle )
125205 if err != nil {
126- return err
206+ return 500 , err
127207 }
128208 msg .Subtitle = str
129209 }
130210 if title := c .Params ("title" ); title != "" {
131211 str , err := url .QueryUnescape (title )
132212 if err != nil {
133- return err
213+ return 500 , err
134214 }
135215 msg .Title = str
136216 }
137217 if body := c .Params ("body" ); body != "" {
138218 str , err := url .QueryUnescape (body )
139219 if err != nil {
140- return err
220+ return 500 , err
141221 }
142222 msg .Body = str
143223 }
144224
145225 if msg .DeviceKey == "" {
146- return c . Status ( 400 ). JSON ( failed ( 400 , "device key is empty" ) )
226+ return 400 , fmt . Errorf ( "device key is empty" )
147227 }
148228
149229 if msg .Body == "" && msg .Title == "" && msg .Subtitle == "" {
@@ -152,13 +232,13 @@ func push(c *fiber.Ctx, params map[string]interface{}) error {
152232
153233 deviceToken , err := db .DeviceTokenByKey (msg .DeviceKey )
154234 if err != nil {
155- return c . Status ( 400 ). JSON ( failed ( 400 , "failed to get device token: %v" , err ) )
235+ return 400 , fmt . Errorf ( "failed to get device token: %v" , err )
156236 }
157237 msg .DeviceToken = deviceToken
158238
159239 err = apns .Push (& msg )
160240 if err != nil {
161- return c . Status ( 500 ). JSON ( failed ( 500 , "push failed: %v" , err ) )
241+ return 500 , fmt . Errorf ( "push failed: %v" , err )
162242 }
163- return c . JSON ( success ())
243+ return 200 , nil
164244}
0 commit comments